/**
 * Computes the evidence lower bound (ELBO) over all active nodes of {@code this.vmp}.
 *
 * <p>Two contributions are added together:
 * <ul>
 *   <li>for every active, non-observed node, the value returned by
 *       {@code this.vmp.computeELBO(node)};</li>
 *   <li>for every active, observed node, a term computed here from the node's moment parents.
 *       For each index {@code k} of the first multinomial parent's moments, it multiplies three
 *       factors: the node's sufficient statistic at the observed word index, the moment of the
 *       matching Dirichlet parent at that index, and the first parent's moment at {@code k}.</li>
 * </ul>
 *
 * <p>Observed nodes must carry an {@code EF_BaseDistribution_MultinomialParents} distribution
 * whose base conditionals are {@code EF_SparseMultinomial_Dirichlet}. Otherwise the casts below
 * throw {@link ClassCastException}.
 *
 * @return the summed ELBO of the active nodes.
 */
private double computeELBO() {
    // Active latent nodes: the VMP engine evaluates their terms.
    double latentTerm = this.vmp.getNodes()
            .parallelStream()
            .filter(node -> node.isActive() && !node.isObserved())
            .mapToDouble(node -> this.vmp.computeELBO(node))
            .sum();

    // Active observed nodes: the term is computed locally from the parent moments.
    double observedTerm = this.vmp.getNodes()
            .parallelStream()
            .filter(node -> node.isActive() && node.isObserved())
            .mapToDouble(node -> {
                EF_BaseDistribution_MultinomialParents baseDist =
                        (EF_BaseDistribution_MultinomialParents) node.getPDist();
                // The first multinomial parent plays the role of the topic selector.
                Variable topicVar = (Variable) baseDist.getMultinomialParents().get(0);
                Map<Variable, MomentParameters> parentMoments = node.getMomentParents();
                double nodeTerm = 0;
                MomentParameters topicMoments = parentMoments.get(topicVar);
                // Observed value of the main variable, reduced modulo its number of states.
                int word = (int) node.getAssignment().getValue(node.getMainVariable())
                        % node.getMainVariable().getNumberOfStates();
                for (int topic = 0; topic < topicMoments.size(); topic++) {
                    EF_SparseMultinomial_Dirichlet topicDist =
                            (EF_SparseMultinomial_Dirichlet) baseDist.getBaseEFConditionalDistribution(topic);
                    MomentParameters dirichletMoments =
                            parentMoments.get(topicDist.getDirichletVariable());
                    nodeTerm += node.getSufficientStatistics().get(word)
                            * dirichletMoments.get(word)
                            * topicMoments.get(topic);
                }
                return nodeTerm;
            })
            .sum();

    return latentTerm + observedTerm;
}
// Sum of the VMP ELBO terms of every active, non-observed node (sequential stream).
double elbo = this.vmp.getNodes().stream().filter(node -> node.isActive() && !node.isObserved()).mapToDouble(node -> this.vmp.computeELBO(node)).sum();
// Pass `prior` to svb.updateNaturalParameterPrior and `updatedPrior` to
// svb.updateNaturalParameterPosteriors. Then set basedELBO to the sum of the VMP ELBO terms
// over the plateau's non-replicated nodes.
// NOTE(review): the identical basedELBO statement is executed twice in a row, and the first
// result is overwritten immediately. Unless computeELBO has side effects, the second
// evaluation is redundant work; confirm and drop one.
svb.updateNaturalParameterPrior(prior); svb.updateNaturalParameterPosteriors(updatedPrior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum(); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
// If-branch tail: pass `prior` / `updatedPrior` to svb's prior / posterior update methods, then
// set basedELBO to the sum of the VMP ELBO terms over the plateau's non-replicated nodes.
// Else-branch head: store a deep copy of updatedPrior in this.prior, then compute basedELBO
// the same way.
svb.updateNaturalParameterPrior(prior); svb.updateNaturalParameterPosteriors(updatedPrior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum(); }else{ this.prior=Serialization.deepCopy(updatedPrior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
// Add to KL the sum of the VMP ELBO terms over the plateau's non-replicated nodes.
KL += svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
/**
 * Computes the ELBO of {@code svb} over a distributed data set.
 *
 * <p>The result has two parts:
 * <ul>
 *   <li>the ELBO terms of the plateau's non-replicated nodes, evaluated locally;</li>
 *   <li>the per-batch values produced by {@code ParallelVBMapELBO} over the batched data set,
 *       summed with a reduce.</li>
 * </ul>
 * The model and its current plateau natural-parameter posterior are serialized into the job
 * configuration under the {@code SVB} and {@code PRIOR} keys.
 *
 * <p>{@code svb.setOutput(false)} is called on entry. {@code svb.setOutput(true)} is called on
 * exit, including when the computation fails.
 *
 * @param dataFlink      the data to evaluate.
 * @param svb            the model whose ELBO is computed.
 * @param batchConverter converter used to build the batches; when {@code null} the default
 *                       batching of {@code dataFlink} is used.
 * @return the ELBO; an empty batched data set contributes {@code 0}.
 * @throws IllegalStateException if serializing the model or running the distributed job fails.
 */
public static double computeELBO(DataFlink<DataInstance> dataFlink, SVB svb, Function2<DataFlink<DataInstance>,Integer,DataSet<DataOnMemory<DataInstance>>> batchConverter){
    svb.setOutput(false);
    try {
        // Terms of the nodes that are not replicated per data instance, evaluated locally.
        double elbo = svb.getPlateuStructure().getNonReplictedNodes()
                .mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node))
                .sum();
        return elbo + sumBatchELBOContributions(dataFlink, svb, batchConverter);
    } finally {
        // Restore output even when the computation throws.
        svb.setOutput(true);
    }
}

/**
 * Runs {@code ParallelVBMapELBO} over the batched data set and returns the sum of the per-batch
 * values, or {@code 0} when the reduce yields no value (empty data set).
 *
 * @throws IllegalStateException wrapping any failure of serialization or of the Flink job.
 */
private static double sumBatchELBOContributions(DataFlink<DataInstance> dataFlink, SVB svb, Function2<DataFlink<DataInstance>,Integer,DataSet<DataOnMemory<DataInstance>>> batchConverter) {
    try {
        Configuration config = new Configuration();
        config.setBytes(SVB, Serialization.serializeObject(svb));
        config.setBytes(PRIOR, Serialization.serializeObject(
                svb.getPlateuStructure().getPlateauNaturalParameterPosterior()));

        DataSet<DataOnMemory<DataInstance>> batches;
        if (batchConverter != null) {
            batches = dataFlink.getBatchedDataSet(svb.getWindowsSize(), batchConverter);
        } else {
            batches = dataFlink.getBatchedDataSet(svb.getWindowsSize());
        }

        java.util.List<Double> reduced = batches.map(new ParallelVBMapELBO())
                .withParameters(config)
                .reduce(new ReduceFunction<Double>() {
                    @Override
                    public Double reduce(Double aDouble, Double t1) throws Exception {
                        return aDouble + t1;
                    }
                })
                .collect();

        // Reducing an empty data set produces no element; treat it as a zero contribution
        // instead of failing on get(0).
        return reduced.isEmpty() ? 0.0 : reduced.get(0);
    } catch (Exception e) {
        // Returning only the locally computed part here would silently yield a wrong ELBO.
        throw new IllegalStateException("Failed to compute the ELBO over the batched data set", e);
    }
}
// Pass `prior` to svb.updateNaturalParameterPrior and `updatedPosterior` to
// svb.updateNaturalParameterPosteriors. Then set basedELBO to the sum of the VMP ELBO terms
// over the plateau's non-replicated nodes.
// NOTE(review): the identical basedELBO statement is executed twice in a row, and the first
// result is overwritten immediately. Unless computeELBO has side effects, the second
// evaluation is redundant work; confirm and drop one.
svb.updateNaturalParameterPrior(prior); svb.updateNaturalParameterPosteriors(updatedPosterior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum(); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
// If-branch tail: pass `prior` / `updatedPosterior` to svb's prior / posterior update methods,
// then set basedELBO to the sum of the VMP ELBO terms over the plateau's non-replicated nodes.
// Else-branch head: store a deep copy of updatedPosterior in this.prior, then compute
// basedELBO the same way.
svb.updateNaturalParameterPrior(prior); svb.updateNaturalParameterPosteriors(updatedPosterior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum(); }else{ this.prior=Serialization.deepCopy(updatedPosterior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
// If-branch tail: pass `prior` / `updatedPosterior` to svb's prior / posterior update methods,
// then set basedELBO to the sum of the VMP ELBO terms over the plateau's non-replicated nodes.
// Else-branch head: store a deep copy of updatedPosterior in this.prior, then compute
// basedELBO the same way.
svb.updateNaturalParameterPrior(prior); svb.updateNaturalParameterPosteriors(updatedPosterior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum(); }else{ this.prior=Serialization.deepCopy(updatedPosterior); basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();