public void initLearning() { //VMPParameter vmpParameter = new VMPParameter(this.svb.getPlateuStructure()); //vmpParameter.setMaxGlobaIter(1); //this.svb.getPlateuStructure().setVmp(vmpParameter); this.svb.getPlateuStructure().getVMP().setMaxIter(this.maximumLocalIterations); this.svb.getPlateuStructure().getVMP().setThreshold(this.localThreshold); this.svb.setDAG(this.dag); this.svb.setWindowsSize(batchSize); this.svb.initLearning(); //Init learning is peformed in each mapper. }
/**
 * Sets the output flag of the underlying MessagePassingAlgorithm.
 * The value is forwarded to the objects returned by {@code getVMPTime0()}
 * and {@code getVMPTimeT()} of the plateau structure.
 * @param output a {@code boolean} that represents the output value to be set.
 */
public void setOutput(boolean output) {
    plateauStructure.getVMPTime0().setOutput(output);
    plateauStructure.getVMPTimeT().setOutput(output);
}
private boolean testConvergence() { boolean convergence = false; //Compute lower-bound double newelbo = computeELBO(); double percentage = 100 * Math.abs(newelbo - local_elbo) / Math.abs(local_elbo); if (percentage < this.vmp.getThreshold() || local_iter > this.vmp.getMaxIter()) { convergence = true; } if ((!convergence && (newelbo / this.vmp.getNodes().size() < (local_elbo / this.vmp.getNodes().size() - 0.01)) && local_iter > -1) || Double.isNaN(local_elbo)) { throw new IllegalStateException("The elbo is not monotonically increasing at iter " + local_iter + ": " + percentage + ", " + local_elbo + ", " + newelbo); } local_elbo = newelbo; return convergence; }
/** * {@inheritDoc} */ @Override public <E extends UnivariateDistribution> E getPredictivePosterior(Variable var, int nTimesAhead) { if (timeID==-1){ this.vmpTime0.setEvidence(null); this.vmpTime0.runInference(); this.vmpTime0.getNodes().stream().filter(node -> !node.isObserved()).forEach(node -> { Variable temporalClone = this.model.getDynamicVariables().getInterfaceVariable(node.getMainVariable()); moveNodeQDist(this.vmpTimeT.getNodeOfVar(temporalClone), node); }); this.moveWindow(nTimesAhead-1); E resultQ = this.getFilteredPosterior(var); this.vmpTime0.resetQs(); this.vmpTimeT.resetQs(); return resultQ; }else { Map<Variable, EF_UnivariateDistribution> map = new HashMap<>(); //Create at copy of Qs this.vmpTimeT.getNodes().stream().filter(node -> !node.isObserved()).forEach(node -> map.put(node.getMainVariable(), node.getQDist().deepCopy())); this.moveWindow(nTimesAhead); E resultQ = this.getFilteredPosterior(var); //Come to the original state map.entrySet().forEach(e -> this.vmpTimeT.getNodeOfVar(e.getKey()).setQDist(e.getValue())); return resultQ; } }
local_elbo = Double.NEGATIVE_INFINITY; local_iter = 0; while (!convergence && (local_iter++) < this.vmp.getMaxIter()) { continue; Message<NaturalParameters> selfMessage = this.vmp.newSelfMessage(node); .map(children -> this.vmp.newMessageToParent(children, node)) .reduce(Message::combineNonStateless); this.vmp.updateCombinedMessage(node, selfMessage); .filter(node -> node.isActive() && !node.isObserved()) .forEach(node -> { Message<NaturalParameters> selfMessage = this.vmp.newSelfMessage(node); .map(children -> this.vmp.newMessageToParent(children, node)) .reduce(Message::combineNonStateless); this.vmp.updateCombinedMessage(node, selfMessage); }); if (this.vmp.isOutput()) { System.out.println("N Iter: " + local_iter + ", elbo:" + local_elbo);
// Superstep number reported by the iteration runtime context, shifted down by one
// (presumably to make it zero-based — confirm against getSuperstepNumber()).
int superstep = getIterationRuntimeContext().getSuperstepNumber() - 1;
if (INITIALIZE && superstep==0) {
    // First superstep with INITIALIZE set: install a fresh VMP that inherits the
    // iteration cap and threshold of the VMP currently held by the plateau structure.
    VMP vmp = new VMP();
    vmp.setMaxIter(this.svb.getPlateuStructure().getVMP().getMaxIter());
    vmp.setThreshold(this.svb.getPlateuStructure().getVMP().getThreshold());
    // NOTE(review): the test-ELBO flag is populated from isOutput(), not from a
    // test-ELBO getter — confirm this is intended.
    vmp.setTestELBO(this.svb.getPlateuStructure().getVMP().isOutput());
    this.svb.getPlateuStructure().setVmp(vmp);
    svb.updateNaturalParameterPrior(prior);
    svb.updateNaturalParameterPosteriors(updatedPosterior);
    // Baseline ELBO: sum of the per-node ELBO over the non-replicated nodes.
    basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
}else{
    // Any other case: a deep copy of the updated posterior becomes the prior
    // before the baseline ELBO is recomputed.
    this.prior=Serialization.deepCopy(updatedPosterior);
    basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
public static void main(String[] args) { DataStream<DataInstance> dataInstances = DataStreamLoader.open("/Users/andresmasegosa/Dropbox/Amidst/datasets/NFSAbstracts/abstractByYear/abstract_90.arff"); //DataOnMemory<DataInstance> dataInstances = DataStreamLoader.loadDataOnMemoryFromFile("/Users/andresmasegosa/Dropbox/Amidst/datasets/NFSAbstracts/abstractByYear/abstract_90.arff"); SVB svb = new SVB(); PlateauLDA plateauLDA = new PlateauLDA(dataInstances.getAttributes(),"word","count"); plateauLDA.setNTopics(10); plateauLDA.getVMP().setTestELBO(true); plateauLDA.getVMP().setMaxIter(10); plateauLDA.getVMP().setOutput(true); plateauLDA.getVMP().setThreshold(0.1); svb.setPlateuStructure(plateauLDA); svb.setOutput(true); svb.initLearning(); //System.out.println(dataInstances.getNumberOfDataInstances()); //svb.updateModel(dataInstances); BatchSpliteratorByID.streamOverDocuments(dataInstances, 500).sequential().forEach(batch -> { System.out.println("Batch: "+ batch.getNumberOfDataInstances()); svb.updateModel(batch); }); }
// Select the inference engine for the static model according to searchAlgorithm.
switch(searchAlgorithm) {
    case VMP:
        // VMP instance configured with threshold 0.0001 and at most 3000 iterations.
        // The casts suggest the declared type of staticModelInference does not expose
        // these setters — TODO confirm.
        staticModelInference = new VMP();
        ((VMP)staticModelInference).setThreshold(0.0001);
        ((VMP)staticModelInference).setMaxIter(3000);
        break;
/**
 * Prepares the learning algorithm for this LDA model.
 *
 * When no learning algorithm has been set, an SVB over a PlateauLDA built from
 * the "word" and "count" attributes is created (VMP: test-ELBO off, at most
 * 100 iterations, threshold 0.01). In every case the algorithm is then given a
 * window size of 100, has its output flag set, and is initialised.
 */
@Override
protected void initLearning() {
    if (learningAlgorithm == null) {
        SVB ldaLearner = new SVB();

        plateauLDA = new PlateauLDA(this.atts, "word", "count");
        plateauLDA.setNTopics(ntopics);
        ldaLearner.setPlateuStructure(plateauLDA);

        ldaLearner.getPlateuStructure().getVMP().setTestELBO(false);
        ldaLearner.getPlateuStructure().getVMP().setMaxIter(100);
        ldaLearner.getPlateuStructure().getVMP().setThreshold(0.01);

        learningAlgorithm = ldaLearner;
    }

    learningAlgorithm.setWindowsSize(100);
    learningAlgorithm.setOutput(true);
    learningAlgorithm.initLearning();
    initialized = true;
}
/**
 * Sets the maximum number of iterations for this MessagePassingAlgorithm.
 * The value is forwarded to the objects returned by {@code getVMPTime0()}
 * and {@code getVMPTimeT()} of the plateau structure.
 * @param maxIter an {@code int} that represents the maximum number of iterations to be set.
 */
public void setMaxIter(int maxIter) {
    plateauStructure.getVMPTime0().setMaxIter(maxIter);
    plateauStructure.getVMPTimeT().setMaxIter(maxIter);
}
/**
 * Sets the threshold for this MessagePassingAlgorithm.
 * The value is forwarded to the objects returned by {@code getVMPTime0()}
 * and {@code getVMPTimeT()} of the plateau structure.
 * @param threshold a {@code double} that represents the threshold value to be set.
 */
public void setThreshold(double threshold) {
    plateauStructure.getVMPTime0().setThreshold(threshold);
    plateauStructure.getVMPTimeT().setThreshold(threshold);
}
/**
 * Computes the ELBO over the active nodes of the VMP.
 *
 * Active non-observed nodes contribute whatever {@code vmp.computeELBO(node)}
 * returns. Active observed nodes contribute a term computed directly here from
 * the node's sufficient statistics and the moment parameters of its parents.
 *
 * @return the sum of both contributions.
 */
private double computeELBO() {
    // Contribution of the active, non-observed nodes.
    double latentPart = this.vmp.getNodes()
            .parallelStream()
            .filter(node -> node.isActive())
            .filter(node -> !node.isObserved())
            .mapToDouble(node -> this.vmp.computeELBO(node))
            .sum();

    // Contribution of the active, observed nodes.
    double observedPart = this.vmp.getNodes()
            .parallelStream()
            .filter(node -> node.isActive())
            .filter(node -> node.isObserved())
            .mapToDouble(node -> {
                EF_BaseDistribution_MultinomialParents baseDist = (EF_BaseDistribution_MultinomialParents) node.getPDist();
                // The first multinomial parent is taken as the topic variable.
                Variable topicVar = (Variable) baseDist.getMultinomialParents().get(0);
                Map<Variable, MomentParameters> parentMoments = node.getMomentParents();
                MomentParameters topicMoments = parentMoments.get(topicVar);

                // Assigned value of the node's main variable, cast to int and then
                // reduced modulo the variable's number of states.
                int wordIndex = ((int) node.getAssignment().getValue(node.getMainVariable()))
                        % node.getMainVariable().getNumberOfStates();

                double nodeElbo = 0;
                for (int topic = 0; topic < topicMoments.size(); topic++) {
                    EF_SparseMultinomial_Dirichlet topicDist =
                            (EF_SparseMultinomial_Dirichlet) baseDist.getBaseEFConditionalDistribution(topic);
                    MomentParameters dirichletMoments = parentMoments.get(topicDist.getDirichletVariable());
                    nodeElbo += node.getSufficientStatistics().get(wordIndex) * dirichletMoments.get(wordIndex) * topicMoments.get(topic);
                }
                return nodeElbo;
            })
            .sum();

    return latentPart + observedPart;
}
// Hand the SVB learner `prior` as its natural-parameter prior and `updatedPrior`
// as its natural-parameter posteriors.
svb.updateNaturalParameterPrior(prior);
svb.updateNaturalParameterPosteriors(updatedPrior);
// Baseline ELBO: sum of the per-node ELBO over the non-replicated nodes.
basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
// NOTE(review): verbatim repeat of the assignment above; as shown here it is redundant.
// Possibly the two lines belong to different branches in the full source — confirm.
basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
/**
 * Runs inference at time 0.
 *
 * Afterwards, for every time-0 node that is neither observed nor attached to a
 * parameter variable, {@code moveNodeQDist} is called with the time-T node
 * (index 0) of the corresponding interface variable and the time-0 node.
 */
public void runInferenceTime0() {
    this.vmpTime0.runInference();

    this.vmpTime0.getNodes().stream()
            .filter(node -> !node.isObserved())
            .filter(node -> !node.getMainVariable().isParameterVariable())
            .forEach(node -> {
                Variable interfaceVar = this.dbnModel.getDynamicVariables().getInterfaceVariable(node.getMainVariable());
                moveNodeQDist(this.getNodeOfVarTimeT(interfaceVar, 0), node);
            });
}
// Flag convergence when the ELBO is a valid number (not NaN) and its percentage
// increase is below the threshold configured on the plateau structure's VMP.
if (!Double.isNaN(elbo) && percentageIncrease<this.plateuStructure.getVMP().getThreshold()){
    convergence=true;
// Superstep number reported by the iteration runtime context, shifted down by one
// (presumably to make it zero-based — confirm against getSuperstepNumber()).
int superstep = getIterationRuntimeContext().getSuperstepNumber() - 1;
if (INITIALIZE && superstep==0) {
    // First superstep with INITIALIZE set: install a fresh VMP that inherits the
    // iteration cap and threshold of the VMP currently held by the plateau structure.
    VMP vmp = new VMP();
    vmp.setMaxIter(this.svb.getPlateuStructure().getVMP().getMaxIter());
    vmp.setThreshold(this.svb.getPlateuStructure().getVMP().getThreshold());
    // NOTE(review): the test-ELBO flag is populated from isOutput(), not from a
    // test-ELBO getter — confirm this is intended.
    vmp.setTestELBO(this.svb.getPlateuStructure().getVMP().isOutput());
    this.svb.getPlateuStructure().setVmp(vmp);
    svb.updateNaturalParameterPrior(prior);
    svb.updateNaturalParameterPosteriors(updatedPosterior);
    // Baseline ELBO: sum of the per-node ELBO over the non-replicated nodes.
    basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
}else{
    // Any other case: a deep copy of the updated posterior becomes the prior
    // before the baseline ELBO is recomputed.
    this.prior=Serialization.deepCopy(updatedPosterior);
    basedELBO = svb.getPlateuStructure().getNonReplictedNodes().mapToDouble(node -> svb.getPlateuStructure().getVMP().computeELBO(node)).sum();
// Fixed seed of 0 on the SVB learner.
svb.setSeed(0);
// Configure the VMP held by the SVB's plateau structure: output flag off,
// test-ELBO flag on, at most 1000 iterations, threshold 0.0001.
VMP vmp = svb.getPlateuStructure().getVMP();
vmp.setOutput(false);
vmp.setTestELBO(true);
vmp.setMaxIter(1000);
vmp.setThreshold(0.0001);
// Select the inference engine for the current model according to searchAlgorithm.
switch (searchAlgorithm) {
    case VMP:
        // VMP instance configured with threshold 0.0001 and at most 3000 iterations.
        // The casts suggest the declared type of currentModelInference does not expose
        // these setters — TODO confirm.
        currentModelInference = new VMP();
        ((VMP)currentModelInference).setThreshold(0.0001);
        ((VMP) currentModelInference).setMaxIter(3000);
        break;
/**
 * Prepares the learning algorithm for this model.
 *
 * When no learning algorithm has been set, a default SVB is created (window
 * size 100; VMP: test-ELBO off, at most 100 iterations, threshold 0.00001).
 * The algorithm's window size is then set to {@code windowSize}, and it is
 * given either the DAG or, failing that, the plateau structure of this model.
 * Finally its output flag is set and it is initialised.
 *
 * @throws IllegalArgumentException if neither a DAG nor a plateau structure is available.
 */
protected void initLearning() {
    if (learningAlgorithm == null) {
        SVB defaultLearner = new SVB();
        defaultLearner.setWindowsSize(100);
        defaultLearner.getPlateuStructure().getVMP().setTestELBO(false);
        defaultLearner.getPlateuStructure().getVMP().setMaxIter(100);
        defaultLearner.getPlateuStructure().getVMP().setThreshold(0.00001);
        learningAlgorithm = defaultLearner;
    }

    learningAlgorithm.setWindowsSize(windowSize);

    // The DAG takes precedence over the plateau structure.
    if (this.getDAG() != null) {
        learningAlgorithm.setDAG(this.getDAG());
    } else if (this.getPlateuStructure() != null) {
        ((BayesianParameterLearningAlgorithm) learningAlgorithm).setPlateuStructure(this.getPlateuStructure());
    } else {
        throw new IllegalArgumentException("Non provided dag or PlateauStructure");
    }

    learningAlgorithm.setOutput(true);
    learningAlgorithm.initLearning();
    initialized = true;
}