public void testSumLogProb () { java.util.Random rand = new java.util.Random (3214123); for (int i = 0; i < 10; i++) { double v1 = rand.nextDouble(); double v2 = rand.nextDouble(); double sum1 = Math.log (v1 + v2); double sum2 = Maths.sumLogProb (Math.log(v1), Math.log (v2)); // System.out.println("Summing "+v1+" + "+v2); assertEquals (sum1, sum2, 0.00001); } }
public void testSumLogProb () { java.util.Random rand = new java.util.Random (3214123); for (int i = 0; i < 10; i++) { double v1 = rand.nextDouble(); double v2 = rand.nextDouble(); double sum1 = Math.log (v1 + v2); double sum2 = Maths.sumLogProb (Math.log(v1), Math.log (v2)); // System.out.println("Summing "+v1+" + "+v2); assertEquals (sum1, sum2, 0.00001); } }
/**
 * Verifies that {@code Assignment.setValue} overwrites the value stored for
 * one variable in place while leaving the other variable's value unchanged.
 */
public void testDestructiveAssignment () {
  Variable first = new Variable (2);
  Variable second = new Variable (2);
  Variable[] vars = new Variable[] { first, second };
  Assignment assn = new Assignment (vars, new int[] { 0, 1 });

  // Initial values come from the constructor's value array.
  assertEquals (0, assn.get (first));
  assertEquals (1, assn.get (second));

  assn.setValue (first, 1);

  // Only the updated variable changes.
  assertEquals (1, assn.get (first));
  assertEquals (1, assn.get (second));
}
/**
 * Verifies that {@code Assignment.setValue} overwrites the value stored for
 * one variable in place while leaving the other variable's value unchanged.
 */
public void testDestructiveAssignment () {
  Variable first = new Variable (2);
  Variable second = new Variable (2);
  Variable[] vars = new Variable[] { first, second };
  Assignment assn = new Assignment (vars, new int[] { 0, 1 });

  // Initial values come from the constructor's value array.
  assertEquals (0, assn.get (first));
  assertEquals (1, assn.get (second));

  assn.setValue (first, 1);

  // Only the updated variable changes.
  assertEquals (1, assn.get (first));
  assertEquals (1, assn.get (second));
}
/**
 * Runs junction-tree inference on {@code models[0]} and then {@code models[1]}.
 * It then checks that a breadth-first walk from the root of the tree returned
 * by {@code lookupJunctionTree} reaches as many clusters as
 * {@code clusterPotentials ()} reports, i.e. every cluster is reachable from
 * the root. (Presumably the tree is the one built for the most recent model;
 * confirm against JunctionTreeInferencer.)
 */
public void testJunctionTreeConnectedFromRoot () {
  JunctionTreeInferencer jti = new JunctionTreeInferencer ();
  jti.computeMarginals (models[0]);
  jti.computeMarginals (models[1]);
  JunctionTree jt = jti.lookupJunctionTree ();

  // FIFO frontier: pop from the front, push children at the back.
  LinkedList frontier = new LinkedList ();
  frontier.add (jt.getRoot ());
  int numReached = 0;
  while (!frontier.isEmpty ()) {
    VarSet cluster = (VarSet) frontier.removeFirst ();
    numReached++;
    frontier.addAll (jt.getChildren (cluster));
  }

  assertEquals (jt.clusterPotentials ().size (), numReached);
}
/**
 * Runs junction-tree inference on {@code models[0]} and then {@code models[1]}.
 * It then checks that a breadth-first walk from the root of the tree returned
 * by {@code lookupJunctionTree} reaches as many clusters as
 * {@code clusterPotentials ()} reports, i.e. every cluster is reachable from
 * the root. (Presumably the tree is the one built for the most recent model;
 * confirm against JunctionTreeInferencer.)
 */
public void testJunctionTreeConnectedFromRoot () {
  JunctionTreeInferencer jti = new JunctionTreeInferencer ();
  jti.computeMarginals (models[0]);
  jti.computeMarginals (models[1]);
  JunctionTree jt = jti.lookupJunctionTree ();

  // FIFO frontier: pop from the front, push children at the back.
  LinkedList frontier = new LinkedList ();
  frontier.add (jt.getRoot ());
  int numReached = 0;
  while (!frontier.isEmpty ()) {
    VarSet cluster = (VarSet) frontier.removeFirst ();
    numReached++;
    frontier.addAll (jt.getChildren (cluster));
  }

  assertEquals (jt.clusterPotentials ().size (), numReached);
}
/**
 * Runs TRP on a factor graph containing one variable and one table factor
 * with weights {1, 2}. Checks that the returned marginal has two entries,
 * approximately 1/3 and 2/3 (within 1e-4).
 */
public void testSingletonGraph () {
  Variable v = new Variable (2);
  FactorGraph mdl = new FactorGraph (new Variable[] { v });
  mdl.addFactor (new TableFactor (v, new double[] { 1, 2 }));

  TRP trp = new TRP ();
  trp.computeMarginals (mdl);

  AbstractTableFactor marginal = (AbstractTableFactor) trp.lookupMarginal (v);
  double[] probs = marginal.toValueArray ();

  assertEquals (2, probs.length);
  assertEquals (0.33333, probs[0], 1e-4);
  assertEquals (0.66666, probs[1], 1e-4);
}
/**
 * Runs TRP on a factor graph containing one variable and one table factor
 * with weights {1, 2}. Checks that the returned marginal has two entries,
 * approximately 1/3 and 2/3 (within 1e-4).
 */
public void testSingletonGraph () {
  Variable v = new Variable (2);
  FactorGraph mdl = new FactorGraph (new Variable[] { v });
  mdl.addFactor (new TableFactor (v, new double[] { 1, 2 }));

  TRP trp = new TRP ();
  trp.computeMarginals (mdl);

  AbstractTableFactor marginal = (AbstractTableFactor) trp.lookupMarginal (v);
  double[] probs = marginal.toValueArray ();

  assertEquals (2, probs.length);
  assertEquals (0.33333, probs[0], 1e-4);
  assertEquals (0.66666, probs[1], 1e-4);
}
public void testJointConsistent () throws Exception { for (int i = 0; i < allAlgs.length; i++) { // for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { { int mdlIdx = 13; Inferencer inf = (Inferencer) allAlgs[i].newInstance(); try { FactorGraph mdl = models[mdlIdx]; inf.computeMarginals(mdl); Assignment assn = new Assignment (mdl, new int [mdl.numVariables ()]); assertEquals (Math.log (inf.lookupJoint (assn)), inf.lookupLogJoint (assn), 1e-5); } catch (UnsupportedOperationException e) { // LoopyBP only handles edge ptls logger.warning("Skipping (" + mdlIdx + "," + i + ")\n" + e); throw e; // continue; } } } }
public void testJointConsistent () throws Exception { for (int i = 0; i < allAlgs.length; i++) { // for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { { int mdlIdx = 13; Inferencer inf = (Inferencer) allAlgs[i].newInstance(); try { FactorGraph mdl = models[mdlIdx]; inf.computeMarginals(mdl); Assignment assn = new Assignment (mdl, new int [mdl.numVariables ()]); assertEquals (Math.log (inf.lookupJoint (assn)), inf.lookupLogJoint (assn), 1e-5); } catch (UnsupportedOperationException e) { // LoopyBP only handles edge ptls logger.warning("Skipping (" + mdlIdx + "," + i + ")\n" + e); throw e; // continue; } } } }
/**
 * Builds a uniform chain with {@code RandomGraphs.createUniformChain (3)}.
 * Checks that every inferencer in {@code allAlgs} reports a log joint of
 * {@code -log 8} (within 1e-5) for each assignment produced by the model's
 * assignment iterator.
 *
 * <p>NOTE(review): {@code -log 8} assumes the chain has exactly 8 equally
 * likely joint assignments (e.g. three binary variables). Confirm against
 * createUniformChain.
 *
 * @throws Exception if an inferencer class cannot be instantiated reflectively
 */
public void testUniformJoint () throws Exception {
  FactorGraph mdl = RandomGraphs.createUniformChain (3);
  double expected = -Math.log (8);

  for (int algIdx = 0; algIdx < allAlgs.length; algIdx++) {
    Inferencer inferencer = (Inferencer) allAlgs[algIdx].newInstance ();
    inferencer.computeMarginals (mdl);

    AssignmentIterator iter = mdl.assignmentIterator ();
    while (iter.hasNext ()) {
      Assignment assn = iter.assignment ();
      double actual = inferencer.lookupLogJoint (assn);
      assertEquals ("Incorrect joint for inferencer " + inferencer, expected, actual, 1e-5);
      iter.advance ();
    }
  }
}
public void testQuery () throws Exception { java.util.Random rand = new java.util.Random (15667); for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models [mdlIdx]; int size = rand.nextInt (3) + 2; size = Math.min (size, mdl.varSet ().size ()); Collection vars = CollectionUtils.subset (mdl.variablesSet (), size, rand); Variable[] varArr = (Variable[]) vars.toArray (new Variable [0]); Assignment assn = new Assignment (varArr, new int [size]); BruteForceInferencer brute = new BruteForceInferencer(); Factor joint = brute.joint(mdl); double marginal = joint.marginalize(vars).value (assn); for (int algIdx = 0; algIdx < appxAlgs.length; algIdx++) { Inferencer alg = (Inferencer) appxAlgs[algIdx].newInstance(); if (alg instanceof TRP) continue; // trp can't handle disconnected models, which arise during query() double returned = alg.query (mdl, assn); assertEquals ("Failure on model "+mdlIdx+" alg "+alg, marginal, returned, APPX_EPSILON); } } logger.info ("Test testQuery passed."); }
public void testQuery () throws Exception { java.util.Random rand = new java.util.Random (15667); for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models [mdlIdx]; int size = rand.nextInt (3) + 2; size = Math.min (size, mdl.varSet ().size ()); Collection vars = CollectionUtils.subset (mdl.variablesSet (), size, rand); Variable[] varArr = (Variable[]) vars.toArray (new Variable [0]); Assignment assn = new Assignment (varArr, new int [size]); BruteForceInferencer brute = new BruteForceInferencer(); Factor joint = brute.joint(mdl); double marginal = joint.marginalize(vars).value (assn); for (int algIdx = 0; algIdx < appxAlgs.length; algIdx++) { Inferencer alg = (Inferencer) appxAlgs[algIdx].newInstance(); if (alg instanceof TRP) continue; // trp can't handle disconnected models, which arise during query() double returned = alg.query (mdl, assn); assertEquals ("Failure on model "+mdlIdx+" alg "+alg, marginal, returned, APPX_EPSILON); } } logger.info ("Test testQuery passed."); }
/**
 * For each model, checks the total number of messages sent by TRP and by
 * LoopyBP against closed-form counts:
 * <ul>
 *   <li>TRP: {@code 2 * (numVariables - 1) * iterationsUsed}</li>
 *   <li>LoopyBP: {@code 2 * |edges| * iterationsUsed}</li>
 * </ul>
 * NOTE(review): the method name does not start with {@code test}, so a
 * JUnit 3 runner presumably does not execute it. Confirm it is meant to be
 * disabled.
 */
public void ignoreTestNumMessages () {
  for (UndirectedModel mdl : models) {
    TRP trp = new TRP ();
    trp.computeMarginals (mdl);
    int trpExpected = (mdl.numVariables () - 1) * 2 * trp.iterationsUsed ();
    assertEquals (trpExpected, trp.getTotalMessagesSent ());

    LoopyBP loopy = new LoopyBP ();
    loopy.computeMarginals (mdl);
    int loopyExpected = mdl.getEdgeSet ().size () * 2 * loopy.iterationsUsed ();
    assertEquals (loopyExpected, loopy.getTotalMessagesSent ());
  }
}
/**
 * Builds a uniform chain with {@code RandomGraphs.createUniformChain (3)}.
 * Checks that every inferencer in {@code allAlgs} reports a log joint of
 * {@code -log 8} (within 1e-5) for each assignment produced by the model's
 * assignment iterator.
 *
 * <p>NOTE(review): {@code -log 8} assumes the chain has exactly 8 equally
 * likely joint assignments (e.g. three binary variables). Confirm against
 * createUniformChain.
 *
 * @throws Exception if an inferencer class cannot be instantiated reflectively
 */
public void testUniformJoint () throws Exception {
  FactorGraph mdl = RandomGraphs.createUniformChain (3);
  double expected = -Math.log (8);

  for (int algIdx = 0; algIdx < allAlgs.length; algIdx++) {
    Inferencer inferencer = (Inferencer) allAlgs[algIdx].newInstance ();
    inferencer.computeMarginals (mdl);

    AssignmentIterator iter = mdl.assignmentIterator ();
    while (iter.hasNext ()) {
      Assignment assn = iter.assignment ();
      double actual = inferencer.lookupLogJoint (assn);
      assertEquals ("Incorrect joint for inferencer " + inferencer, expected, actual, 1e-5);
      iter.advance ();
    }
  }
}
/**
 * For each model, checks the total number of messages sent by TRP and by
 * LoopyBP against closed-form counts:
 * <ul>
 *   <li>TRP: {@code 2 * (numVariables - 1) * iterationsUsed}</li>
 *   <li>LoopyBP: {@code 2 * |edges| * iterationsUsed}</li>
 * </ul>
 * NOTE(review): the method name does not start with {@code test}, so a
 * JUnit 3 runner presumably does not execute it. Confirm it is meant to be
 * disabled.
 */
public void ignoreTestNumMessages () {
  for (UndirectedModel mdl : models) {
    TRP trp = new TRP ();
    trp.computeMarginals (mdl);
    int trpExpected = (mdl.numVariables () - 1) * 2 * trp.iterationsUsed ();
    assertEquals (trpExpected, trp.getTotalMessagesSent ());

    LoopyBP loopy = new LoopyBP ();
    loopy.computeMarginals (mdl);
    int loopyExpected = mdl.getEdgeSet ().size () * 2 * loopy.iterationsUsed ();
    assertEquals (loopyExpected, loopy.getTotalMessagesSent ());
  }
}
/**
 * For every model in {@code trees}, checks that TreeBP's {@code lookupJoint}
 * matches brute-force inference, within 1e-15, on each assignment yielded by
 * the model's assignment iterator.
 */
public void testBpJoint () {
  for (FactorGraph mdl : trees) {
    Inferencer bp = new TreeBP ();
    BruteForceInferencer brute = new BruteForceInferencer ();
    brute.computeMarginals (mdl);
    bp.computeMarginals (mdl);

    AssignmentIterator assignments = mdl.assignmentIterator ();
    while (assignments.hasNext ()) {
      Assignment assn = (Assignment) assignments.next ();
      double exact = brute.lookupJoint (assn);
      double fromBp = bp.lookupJoint (assn);
      // NOTE(review): 1e-15 is an extremely tight absolute tolerance for two
      // different floating-point computations. Loosen it if this ever flakes.
      assertEquals (exact, fromBp, 1e-15);
    }
  }
}
/**
 * For every model in {@code trees}, checks that TreeBP's {@code lookupJoint}
 * matches brute-force inference, within 1e-15, on each assignment yielded by
 * the model's assignment iterator.
 */
public void testBpJoint () {
  for (FactorGraph mdl : trees) {
    Inferencer bp = new TreeBP ();
    BruteForceInferencer brute = new BruteForceInferencer ();
    brute.computeMarginals (mdl);
    bp.computeMarginals (mdl);

    AssignmentIterator assignments = mdl.assignmentIterator ();
    while (assignments.hasNext ()) {
      Assignment assn = (Assignment) assignments.next ();
      double exact = brute.lookupJoint (assn);
      double fromBp = bp.lookupJoint (assn);
      // NOTE(review): 1e-15 is an extremely tight absolute tolerance for two
      // different floating-point computations. Loosen it if this ever flakes.
      assertEquals (exact, fromBp, 1e-15);
    }
  }
}
public void testLoopyCaching () { FactorGraph mdl1 = models[4]; FactorGraph mdl2 = models[5]; Variable var = mdl1.get (0); LoopyBP inferencer = new LoopyBP (); inferencer.setUseCaching (true); inferencer.computeMarginals (mdl1); Factor origPtl = inferencer.lookupMarginal (var); assertTrue (2 < inferencer.iterationsUsed ()); // confuse the inferencer inferencer.computeMarginals (mdl2); // make sure we have cached, correct results inferencer.computeMarginals (mdl1); Factor sndPtl = inferencer.lookupMarginal (var); // note that we can't use an epsilon here, that's less than our convergence criteria. assertTrue ("Huh? Original potential:"+origPtl+"After: "+sndPtl, origPtl.almostEquals (sndPtl, 1e-4)); assertEquals (1, inferencer.iterationsUsed ()); }
public void testLoopyCaching () { FactorGraph mdl1 = models[4]; FactorGraph mdl2 = models[5]; Variable var = mdl1.get (0); LoopyBP inferencer = new LoopyBP (); inferencer.setUseCaching (true); inferencer.computeMarginals (mdl1); Factor origPtl = inferencer.lookupMarginal (var); assertTrue (2 < inferencer.iterationsUsed ()); // confuse the inferencer inferencer.computeMarginals (mdl2); // make sure we have cached, correct results inferencer.computeMarginals (mdl1); Factor sndPtl = inferencer.lookupMarginal (var); // note that we can't use an epsilon here, that's less than our convergence criteria. assertTrue ("Huh? Original potential:"+origPtl+"After: "+sndPtl, origPtl.almostEquals (sndPtl, 1e-4)); assertEquals (1, inferencer.iterationsUsed ()); }