/**
 * Returns the exponential family posterior of a parameter variable at time 0.
 *
 * @param var a given parameter {@link Variable} object.
 * @param <E> a subtype distribution of {@link EF_UnivariateDistribution}.
 * @return the {@link EF_UnivariateDistribution} (the Q distribution) of the time-0 node
 *     associated with {@code var}.
 * @throws IllegalArgumentException if {@code var} is not a parameter variable, or if no
 *     time-0 node is registered for it.
 */
@SuppressWarnings("unchecked") // E is chosen by the caller; the cast below cannot be checked at runtime
public <E extends EF_UnivariateDistribution> E getEFParameterPosteriorTime0(Variable var) {
    if (!var.isParameterVariable()) {
        throw new IllegalArgumentException("Only parameter variables can be queried");
    }
    if (!this.parametersToNodeTime0.containsKey(var)) {
        // Report the missing mapping explicitly instead of letting get(var) yield null and NPE.
        throw new IllegalArgumentException("No time 0 node found for parameter variable: " + var);
    }
    return (E) this.parametersToNodeTime0.get(var).getQDist();
}
/**
 * Returns the exponential family posterior of a parameter variable at time T.
 *
 * @param var a given parameter {@link Variable} object.
 * @param <E> a subtype distribution of {@link EF_UnivariateDistribution}.
 * @return the {@link EF_UnivariateDistribution} (the Q distribution) of the time-T node
 *     associated with {@code var}.
 * @throws IllegalArgumentException if {@code var} is not a parameter variable, or if no
 *     time-T node is registered for it.
 */
@SuppressWarnings("unchecked") // E is chosen by the caller; the cast below cannot be checked at runtime
public <E extends EF_UnivariateDistribution> E getEFParameterPosteriorTimeT(Variable var) {
    if (!var.isParameterVariable()) {
        throw new IllegalArgumentException("Only parameter variables can be queried");
    }
    if (!this.parametersToNodeTimeT.containsKey(var)) {
        // Report the missing mapping explicitly instead of letting get(var) yield null and NPE.
        throw new IllegalArgumentException("No time T node found for parameter variable: " + var);
    }
    return (E) this.parametersToNodeTimeT.get(var).getQDist();
}
/**
 * Applies every constraint registered in {@code constraintMap} for the main variable of
 * {@code node}. For each constraint, the first parent of {@code node} that
 * {@code Constraints.match} accepts is deactivated, and its Q distribution is fixed to the
 * constraint's value via {@code Constraints.fixValue}.
 *
 * @param node the {@link Node} whose parents are constrained.
 * @throws IllegalStateException if no parent of {@code node} matches one of the constraints.
 */
public void setConstraint(Node node){
    for (Constraint constraint : this.constraintMap.get(node.getMainVariable())) {
        Node matchingParent = node.getParents().stream()
                .filter(candidate -> eu.amidst.core.constraints.Constraints.match(
                        node.getPDist(), candidate.getMainVariable(), constraint))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No constraint for the given parameter"));

        matchingParent.setActive(false);
        eu.amidst.core.constraints.Constraints.fixValue(matchingParent.getQDist(), constraint.getValue());
    }
}
}
/**
 * Transfers the Q distribution of {@code fromNode} onto {@code toTemporalCloneNode}.
 * A copy is obtained via {@code deepCopy} with the clone's main variable, and that single
 * copied instance is installed as both the P and the Q distribution of the clone node.
 *
 * @param toTemporalCloneNode the {@link Node} that receives the copied distribution.
 * @param fromNode the {@link Node} whose Q distribution is copied.
 */
private static void moveNodeQDist(Node toTemporalCloneNode, Node fromNode){
    EF_UnivariateDistribution sourceQ = fromNode.getQDist();
    EF_UnivariateDistribution cloneDist = sourceQ.deepCopy(toTemporalCloneNode.getMainVariable());
    toTemporalCloneNode.setPDist(cloneDist);
    toTemporalCloneNode.setQDist(cloneDist);
}
/**
 * Copies the Q distribution of {@code fromNode} (through {@code deepCopy} with the main
 * variable of {@code toTemporalCloneNode}) and sets that copy as both the P and the Q
 * distribution of {@code toTemporalCloneNode}.
 *
 * @param toTemporalCloneNode the {@link Node} being overwritten.
 * @param fromNode the {@link Node} providing the Q distribution.
 */
private static void moveNodeQDist(Node toTemporalCloneNode, Node fromNode){
    EF_UnivariateDistribution copied =
            fromNode.getQDist().deepCopy(toTemporalCloneNode.getMainVariable());
    // P and Q of the clone end up referencing the same copied instance.
    toTemporalCloneNode.setPDist(copied);
    toTemporalCloneNode.setQDist(copied);
}
/** * {@inheritDoc} */ @Override public <E extends UnivariateDistribution> E getPredictivePosterior(Variable var, int nTimesAhead) { if (timeID==-1){ this.vmpTime0.setEvidence(null); this.vmpTime0.runInference(); this.vmpTime0.getNodes().stream().filter(node -> !node.isObserved()).forEach(node -> { Variable temporalClone = this.model.getDynamicVariables().getInterfaceVariable(node.getMainVariable()); moveNodeQDist(this.vmpTimeT.getNodeOfVar(temporalClone), node); }); this.moveWindow(nTimesAhead-1); E resultQ = this.getFilteredPosterior(var); this.vmpTime0.resetQs(); this.vmpTimeT.resetQs(); return resultQ; }else { Map<Variable, EF_UnivariateDistribution> map = new HashMap<>(); //Create at copy of Qs this.vmpTimeT.getNodes().stream().filter(node -> !node.isObserved()).forEach(node -> map.put(node.getMainVariable(), node.getQDist().deepCopy())); this.moveWindow(nTimesAhead); E resultQ = this.getFilteredPosterior(var); //Come to the original state map.entrySet().forEach(e -> this.vmpTimeT.getNodeOfVar(e.getKey()).setQDist(e.getValue())); return resultQ; } }
for (Node node : this.plateuStructure.getNonReplictedNodes().collect(Collectors.toList())) { Map<Variable, MomentParameters> momentParents = node.getMomentParents(); kl_q_p0[count] = local_kl((EF_Dirichlet)node.getQDist(),(EF_Dirichlet)node.getPDist()); count++; for (Node node : this.plateuStructure.getNonReplictedNodes().collect(Collectors.toList())) { Map<Variable, MomentParameters> momentParents = node.getMomentParents(); kl_q_pt_1[count] = local_kl((EF_Dirichlet)node.getQDist(),(EF_Dirichlet)node.getPDist()); count++;
for (Node node : this.plateuStructure.getNonReplictedNodes().collect(Collectors.toList())) { EF_Dirichlet dirichletQ = (EF_Dirichlet)node.getQDist(); EF_Dirichlet dirichletP = (EF_Dirichlet)node.getPDist(); double[] localKL = computeLocalKLDirichletBinary(dirichletQ,dirichletP); for (Node node : this.plateuStructure.getNonReplictedNodes().collect(Collectors.toList())) { EF_Dirichlet dirichletQ = (EF_Dirichlet)node.getQDist(); EF_Dirichlet dirichletP = (EF_Dirichlet)node.getPDist(); double[] localKL = computeLocalKLDirichletBinary(dirichletQ,dirichletP);