/**
 * Checks whether message passing has converged.
 * <p>
 * For every edge of the region graph {@code rg}, compares the message stored for
 * {@code (edge.from, edge.to)} in {@code oldMessages} with the one stored in
 * {@code newMessages}, using {@code Factor.almostEquals} with tolerance
 * {@code THRESHOLD}.
 *
 * @return {@code true} if every edge's old and new messages agree within
 *         {@code THRESHOLD}; {@code false} as soon as one edge differs, or an
 *         edge that had an old message has no new one
 */
private boolean hasConverged ()
{
  for (Iterator it = rg.edgeIterator (); it.hasNext ();) {
    RegionEdge edge = (RegionEdge) it.next ();
    Factor oldMsg = oldMessages.getMessage (edge.from, edge.to);
    Factor newMsg = newMessages.getMessage (edge.from, edge.to);

    if (oldMsg == null) {
      // An edge with no old message is expected to have no new message either.
      assert newMsg == null;
      continue;
    }

    // A message that existed before but is now missing counts as a change;
    // checking explicitly avoids passing null into almostEquals.
    if (newMsg == null || !oldMsg.almostEquals (newMsg, THRESHOLD)) {
      return false;
    }
  }
  return true;
}
/**
 * Asserts that two arrays of marginals agree element-wise.
 * <p>
 * Fails with an explicit message if the arrays differ in length, rather than
 * throwing ArrayIndexOutOfBoundsException or silently ignoring trailing entries
 * of {@code post}. Then compares {@code pre[i]} with {@code post[i]} using
 * {@code almostEquals} with a tolerance of 1e-3, dumping both factors on failure.
 *
 * @param msg  prefix for every failure message
 * @param pre  first array of marginals
 * @param post second array of marginals, compared position by position with {@code pre}
 */
private void compareMarginals (String msg, Factor[] pre, Factor[] post)
{
  assertTrue (msg + "\nmarginal count mismatch: expected " + pre.length + " but got " + post.length,
              pre.length == post.length);
  for (int i = 0; i < pre.length; i++) {
    Factor ptl1 = pre[i];
    Factor ptl2 = post[i];
    assertTrue (msg + "\n" + ptl1.dumpToString () + "\n" + ptl2.dumpToString (),
                ptl1.almostEquals (ptl2, 1e-3));
  }
}
/**
 * Slicing {@code factor} at the assignment alpha = 1.0 must produce an
 * AbstractTableFactor over exactly {@code vars} whose values are approximately
 * {1, e^-1, e^-1, 1}.
 */
public void testSlice ()
{
  Assignment alphaIsOne = new Assignment (alpha, 1.0);
  Factor result = factor.slice (alphaIsOne);

  assertTrue (result instanceof AbstractTableFactor);
  assertTrue (result.varSet ().equals (vars));

  double expOfMinusOne = Math.exp (-1);
  double[] expectedValues = { 1.0, expOfMinusOne, expOfMinusOne, 1.0 };
  TableFactor expected = new TableFactor (vars, expectedValues);
  assertTrue (result.almostEquals (expected));
}
/**
 * Asserts that two arrays of marginals agree element-wise.
 * <p>
 * Fails with an explicit message if the arrays differ in length, rather than
 * throwing ArrayIndexOutOfBoundsException or silently ignoring trailing entries
 * of {@code post}. Then compares {@code pre[i]} with {@code post[i]} using
 * {@code almostEquals} with a tolerance of 1e-3, dumping both factors on failure.
 *
 * @param msg  prefix for every failure message
 * @param pre  first array of marginals
 * @param post second array of marginals, compared position by position with {@code pre}
 */
private void compareMarginals (String msg, Factor[] pre, Factor[] post)
{
  assertTrue (msg + "\nmarginal count mismatch: expected " + pre.length + " but got " + post.length,
              pre.length == post.length);
  for (int i = 0; i < pre.length; i++) {
    Factor ptl1 = pre[i];
    Factor ptl2 = post[i];
    assertTrue (msg + "\n" + ptl1.dumpToString () + "\n" + ptl2.dumpToString (),
                ptl1.almostEquals (ptl2, 1e-3));
  }
}
/**
 * Slicing {@code factor} at the assignment alpha = 1.0 must produce an
 * AbstractTableFactor over exactly {@code vars} whose values are approximately
 * {1, e^-1, e^-1, 1}.
 */
public void testSlice ()
{
  Assignment alphaIsOne = new Assignment (alpha, 1.0);
  Factor result = factor.slice (alphaIsOne);

  assertTrue (result instanceof AbstractTableFactor);
  assertTrue (result.varSet ().equals (vars));

  double expOfMinusOne = Math.exp (-1);
  double[] expectedValues = { 1.0, expOfMinusOne, expOfMinusOne, 1.0 };
  TableFactor expected = new TableFactor (vars, expectedValues);
  assertTrue (result.almostEquals (expected));
}
/**
 * For ten random edge potentials over the first two variables of models[0]
 * (drawn from a fixed seed), normalizing a LogTableFactor copy must give
 * approximately the same factor as normalizing a plain duplicate.
 */
public void testLogNormalize ()
{
  FactorGraph graph = models [0];
  Iterator varIt = graph.variablesIterator ();
  Variable first = (Variable) varIt.next ();
  Variable second = (Variable) varIt.next ();
  Random rng = new Random (3214123);

  int trial = 0;
  while (trial < 10) {
    Factor potential = randomEdgePotential (rng, first, second);
    Factor logNormed = new LogTableFactor ((AbstractTableFactor) potential);
    Factor plainNormed = potential.duplicate ();
    logNormed.normalize ();
    plainNormed.normalize ();
    assertTrue ("LogNormalize failed! Correct: " + plainNormed + " Log-normed: " + logNormed,
                logNormed.almostEquals (plainNormed));
    trial++;
  }
}
/**
 * For ten random edge potentials over the first two variables of models[0]
 * (drawn from a fixed seed), normalizing a LogTableFactor copy must give
 * approximately the same factor as normalizing a plain duplicate.
 */
public void testLogNormalize ()
{
  FactorGraph graph = models [0];
  Iterator varIt = graph.variablesIterator ();
  Variable first = (Variable) varIt.next ();
  Variable second = (Variable) varIt.next ();
  Random rng = new Random (3214123);

  int trial = 0;
  while (trial < 10) {
    Factor potential = randomEdgePotential (rng, first, second);
    Factor logNormed = new LogTableFactor ((AbstractTableFactor) potential);
    Factor plainNormed = potential.duplicate ();
    logNormed.normalize ();
    plainNormed.normalize ();
    assertTrue ("LogNormalize failed! Correct: " + plainNormed + " Log-normed: " + logNormed,
                logNormed.almostEquals (plainNormed));
    trial++;
  }
}
/**
 * For ten random edge potentials over the first two variables of models[0]
 * (drawn from a fixed seed), marginalizing in log space must agree with
 * marginalizing first and converting to a LogTableFactor afterwards.
 * The check is done for each of the two variables in turn, and every failure
 * reports both factors and the variable marginalized onto.
 */
public void testLogMarginalize ()
{
  FactorGraph mdl = models [0];
  Iterator it = mdl.variablesIterator ();
  Variable v1 = (Variable) it.next ();
  Variable v2 = (Variable) it.next ();
  Random rand = new Random (3214123);
  Variable[] targets = { v1, v2 };
  for (int i = 0; i < 10; i++) {
    Factor ptl = randomEdgePotential (rand, v1, v2);
    for (int j = 0; j < targets.length; j++) {
      Variable target = targets[j];
      // Log-then-marginalize vs. marginalize-then-log (the latter is the reference).
      Factor logmarg = new LogTableFactor ((AbstractTableFactor) ptl).marginalize (target);
      Factor marglog = new LogTableFactor ((AbstractTableFactor) ptl.marginalize (target));
      assertTrue ("LogMarg failed onto " + target + "! Correct: " + marglog + " Log-marg: " + logmarg,
                  logmarg.almostEquals (marglog));
    }
  }
}
/**
 * For ten random edge potentials over the first two variables of models[0]
 * (drawn from a fixed seed), marginalizing in log space must agree with
 * marginalizing first and converting to a LogTableFactor afterwards.
 * The check is done for each of the two variables in turn, and every failure
 * reports both factors and the variable marginalized onto.
 */
public void testLogMarginalize ()
{
  FactorGraph mdl = models [0];
  Iterator it = mdl.variablesIterator ();
  Variable v1 = (Variable) it.next ();
  Variable v2 = (Variable) it.next ();
  Random rand = new Random (3214123);
  Variable[] targets = { v1, v2 };
  for (int i = 0; i < 10; i++) {
    Factor ptl = randomEdgePotential (rand, v1, v2);
    for (int j = 0; j < targets.length; j++) {
      Variable target = targets[j];
      // Log-then-marginalize vs. marginalize-then-log (the latter is the reference).
      Factor logmarg = new LogTableFactor ((AbstractTableFactor) ptl).marginalize (target);
      Factor marglog = new LogTableFactor ((AbstractTableFactor) ptl.marginalize (target));
      assertTrue ("LogMarg failed onto " + target + "! Correct: " + marglog + " Log-marg: " + logmarg,
                  logmarg.almostEquals (marglog));
    }
  }
}
public void testTrpViterbiEquiv2() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { FactorGraph mdl = trees[mdlIdx]; Inferencer maxprod = TreeBP.createForMaxProduct (); TRP trp = TRP.createForMaxProduct (); maxprod.computeMarginals (mdl); trp.computeMarginals (mdl); // TRP should return same results as viterbi for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPotBp = maxprod.lookupMarginal (var); Factor maxPotTrp = trp.lookupMarginal (var); assertTrue ("TRP maxprod propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp, maxPotBp.almostEquals (maxPotTrp)); } } }
public void testTrpViterbiEquiv2() { for (int mdlIdx = 0; mdlIdx < trees.length; mdlIdx++) { FactorGraph mdl = trees[mdlIdx]; Inferencer maxprod = TreeBP.createForMaxProduct (); TRP trp = TRP.createForMaxProduct (); maxprod.computeMarginals (mdl); trp.computeMarginals (mdl); // TRP should return same results as viterbi for (Iterator it = mdl.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); Factor maxPotBp = maxprod.lookupMarginal (var); Factor maxPotTrp = trp.lookupMarginal (var); assertTrue ("TRP maxprod propagation not the same as plain maxProd!\n" + "Trp " + maxPotTrp + "\n Plain maxprod " + maxPotBp, maxPotBp.almostEquals (maxPotTrp)); } } }
/**
 * After multiplying tbl1 and tbl2 into an empty FactorGraph, marginalizing
 * onto vars[1] must give approximately the table {0.81, 0.9}.
 */
public void testMarginalize ()
{
  FactorGraph product = new FactorGraph ();
  product.multiplyBy (tbl1);
  product.multiplyBy (tbl2);

  double[] expectedValues = { 0.81, 0.9 };
  Factor expected = new TableFactor (vars[1], expectedValues);
  Factor actual = product.marginalize (vars[1]);
  assertTrue (expected.almostEquals (actual));
}
/**
 * After multiplying tbl1 and tbl2 into an empty FactorGraph, marginalizing
 * onto vars[1] must give approximately the table {0.81, 0.9}.
 */
public void testMarginalize ()
{
  FactorGraph product = new FactorGraph ();
  product.multiplyBy (tbl1);
  product.multiplyBy (tbl2);

  double[] expectedValues = { 0.81, 0.9 };
  Factor expected = new TableFactor (vars[1], expectedValues);
  Factor actual = product.marginalize (vars[1]);
  assertTrue (expected.almostEquals (actual));
}
/**
 * Asserts that two inferencers report matching single-variable marginals for
 * every variable of {@code fg}, within a tolerance of 1e-5.
 *
 * @param msg  prefix for every failure message; both factors are dumped after it
 * @param fg   graph whose variables (indices 0 .. numVariables()-1) are checked
 * @param inf1 first inferencer; presumably already run on {@code fg} — confirm at call sites
 * @param inf2 second inferencer, compared against {@code inf1}
 */
private void compareMarginals (String msg, FactorGraph fg, Inferencer inf1, Inferencer inf2)
{
  for (int v = 0; v < fg.numVariables (); v++) {
    Variable var = fg.get (v);
    Factor first = inf1.lookupMarginal (var);
    Factor second = inf2.lookupMarginal (var);
    String detail = msg + "\n" + first.dumpToString () + "\n" + second.dumpToString ();
    assertTrue (detail, first.almostEquals (second, 1e-5));
  }
}
/**
 * Tests that running TRP does not inadvertently modify the potentials of the
 * graph it runs on: the brute-force joint of the triangle model must be
 * approximately the same before and after TRP (with an IterationTerminator of 25)
 * computes marginals on it.
 */
public void testTrpNonDestructivity ()
{
  FactorGraph triangle = createTriangle ();
  TRP trp = new TRP (new TRP.IterationTerminator (25));
  BruteForceInferencer exact = new BruteForceInferencer ();

  Factor before = exact.joint (triangle);
  trp.computeMarginals (triangle);
  Factor after = exact.joint (triangle);

  assertTrue (before.almostEquals (after));
  logger.info ("Test trpNonDestructivity passed.");
}
/**
 * Asserts that two inferencers report matching single-variable marginals for
 * every variable of {@code fg}, within a tolerance of 1e-5.
 *
 * @param msg  prefix for every failure message; both factors are dumped after it
 * @param fg   graph whose variables (indices 0 .. numVariables()-1) are checked
 * @param inf1 first inferencer; presumably already run on {@code fg} — confirm at call sites
 * @param inf2 second inferencer, compared against {@code inf1}
 */
private void compareMarginals (String msg, FactorGraph fg, Inferencer inf1, Inferencer inf2)
{
  for (int v = 0; v < fg.numVariables (); v++) {
    Variable var = fg.get (v);
    Factor first = inf1.lookupMarginal (var);
    Factor second = inf2.lookupMarginal (var);
    String detail = msg + "\n" + first.dumpToString () + "\n" + second.dumpToString ();
    assertTrue (detail, first.almostEquals (second, 1e-5));
  }
}
public void testEarlyStopping () { FactorGraph grid = RandomGraphs.randomAttractiveGrid (5, 0.5, new Random (2413421)); TRP trp = new TRP (new TRP.IterationTerminator (1)); trp.setRandomSeed (14312341); trp.computeMarginals (grid); boolean oneIsDifferent = false; // check no exceptions thrown when asking for all marginals, // and check that at least one factors' belief has changed // from the choice at zero iterations. for (Iterator it = grid.factorsIterator (); it.hasNext();) { Factor f = (Factor) it.next (); Factor marg = trp.lookupMarginal (f.varSet ());// test no exception thrown if (!marg.almostEquals (f.duplicate ().normalize ())) { oneIsDifferent = true; } } assertTrue (oneIsDifferent); }
public void testEarlyStopping () { FactorGraph grid = RandomGraphs.randomAttractiveGrid (5, 0.5, new Random (2413421)); TRP trp = new TRP (new TRP.IterationTerminator (1)); trp.setRandomSeed (14312341); trp.computeMarginals (grid); boolean oneIsDifferent = false; // check no exceptions thrown when asking for all marginals, // and check that at least one factors' belief has changed // from the choice at zero iterations. for (Iterator it = grid.factorsIterator (); it.hasNext();) { Factor f = (Factor) it.next (); Factor marg = trp.lookupMarginal (f.varSet ());// test no exception thrown if (!marg.almostEquals (f.duplicate ().normalize ())) { oneIsDifferent = true; } } assertTrue (oneIsDifferent); }
/**
 * Tests that running TRP does not inadvertently modify the potentials of the
 * graph it runs on: the brute-force joint of the triangle model must be
 * approximately the same before and after TRP (with an IterationTerminator of 25)
 * computes marginals on it.
 */
public void testTrpNonDestructivity ()
{
  FactorGraph triangle = createTriangle ();
  TRP trp = new TRP (new TRP.IterationTerminator (25));
  BruteForceInferencer exact = new BruteForceInferencer ();

  Factor before = exact.joint (triangle);
  trp.computeMarginals (triangle);
  Factor after = exact.joint (triangle);

  assertTrue (before.almostEquals (after));
  logger.info ("Test trpNonDestructivity passed.");
}
public void testLoopyCaching () { FactorGraph mdl1 = models[4]; FactorGraph mdl2 = models[5]; Variable var = mdl1.get (0); LoopyBP inferencer = new LoopyBP (); inferencer.setUseCaching (true); inferencer.computeMarginals (mdl1); Factor origPtl = inferencer.lookupMarginal (var); assertTrue (2 < inferencer.iterationsUsed ()); // confuse the inferencer inferencer.computeMarginals (mdl2); // make sure we have cached, correct results inferencer.computeMarginals (mdl1); Factor sndPtl = inferencer.lookupMarginal (var); // note that we can't use an epsilon here, that's less than our convergence criteria. assertTrue ("Huh? Original potential:"+origPtl+"After: "+sndPtl, origPtl.almostEquals (sndPtl, 1e-4)); assertEquals (1, inferencer.iterationsUsed ()); }