public void testDirectedJt () { DirectedModel bn = createDirectedModel (); BruteForceInferencer brute = new BruteForceInferencer (); brute.computeMarginals (bn); JunctionTreeInferencer jt = new JunctionTreeInferencer (); jt.computeMarginals (bn); compareMarginals ("Error comparing junction tree to brute on directed model!", bn, brute, jt); }
public void testDirectedJt () { DirectedModel bn = createDirectedModel (); BruteForceInferencer brute = new BruteForceInferencer (); brute.computeMarginals (bn); JunctionTreeInferencer jt = new JunctionTreeInferencer (); jt.computeMarginals (bn); compareMarginals ("Error comparing junction tree to brute on directed model!", bn, brute, jt); }
private void testSerializationForAlg (Inferencer alg) throws IOException, ClassNotFoundException { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models [mdlIdx]; // Copy the inferencer before calling b/c of random seed issues. Inferencer alg2 = (Inferencer) TestSerializable.cloneViaSerialization (alg); alg.computeMarginals(mdl); Factor[] pre = collectAllMarginals (mdl, alg); alg2.computeMarginals (mdl); Factor[] post2 = collectAllMarginals (mdl, alg2); compareMarginals ("Error comparing marginals after serialzation on model "+mdl, pre, post2); } }
private void testSerializationForAlg (Inferencer alg) throws IOException, ClassNotFoundException { for (int mdlIdx = 0; mdlIdx < models.length; mdlIdx++) { FactorGraph mdl = models [mdlIdx]; // Copy the inferencer before calling b/c of random seed issues. Inferencer alg2 = (Inferencer) TestSerializable.cloneViaSerialization (alg); alg.computeMarginals(mdl); Factor[] pre = collectAllMarginals (mdl, alg); alg2.computeMarginals (mdl); Factor[] post2 = collectAllMarginals (mdl, alg2); compareMarginals ("Error comparing marginals after serialzation on model "+mdl, pre, post2); } }
public void testTrpTreeList () { FactorGraph model = createTriangle(); model.getVariable (0).setLabel ("V0"); model.getVariable (1).setLabel ("V1"); model.getVariable (2).setLabel ("V2"); List readers = new ArrayList (); for (int i = 0; i < treeStrs.length; i++) { readers.add (new StringReader (treeStrs[i])); } TRP trp = new TRP().setTerminator (new TRP.DefaultConvergenceTerminator()) .setFactory (TRP.TreeListFactory.makeFromReaders (model, readers)); trp.computeMarginals(model); Inferencer jt = new BruteForceInferencer (); jt.computeMarginals (model); compareMarginals ("", model, trp, jt); }
public void testTrpTreeList () { FactorGraph model = createTriangle(); model.getVariable (0).setLabel ("V0"); model.getVariable (1).setLabel ("V1"); model.getVariable (2).setLabel ("V2"); List readers = new ArrayList (); for (int i = 0; i < treeStrs.length; i++) { readers.add (new StringReader (treeStrs[i])); } TRP trp = new TRP().setTerminator (new TRP.DefaultConvergenceTerminator()) .setFactory (TRP.TreeListFactory.makeFromReaders (model, readers)); trp.computeMarginals(model); Inferencer jt = new BruteForceInferencer (); jt.computeMarginals (model); compareMarginals ("", model, trp, jt); }