/**
 * Returns the reference that identifies this vertex; the vertex id itself serves as that reference.
 *
 * @return this vertex's id
 */
@Override
public VariableReference getReference() {
    return this.getId();
}
/**
 * Convenience overload that delegates to the id-based {@code withRespectTo} using the given
 * vertex's id.
 *
 * @param vertex the vertex whose id is looked up
 * @return whatever the id-based overload returns for {@code vertex.getId()}
 */
public DoubleTensor withRespectTo(Vertex vertex) {
    return this.withRespectTo(vertex.getId());
}
/**
 * Renders the vertex as {@code id[ (label)]: SimpleClassName[(value)]}. The label part appears
 * only when a label is set, and the value part only when {@code hasValue()} is true.
 */
@Override
public String toString() {
    StringBuilder description = new StringBuilder();
    description.append(getId());
    if (getLabel() != null) {
        description.append(" (").append(getLabel()).append(")");
    }
    description.append(": ").append(getClass().getSimpleName());
    if (hasValue()) {
        description.append('(').append(getValue()).append(')');
    }
    return description.toString();
}
/**
 * Applies every value held in {@code state} to the vertices of this network whose id matches the
 * value's reference.
 *
 * @param state the values to apply; each is fetched with {@code state.get(reference)}
 */
public void setState(NetworkState state) {
    for (VariableReference reference : state.getVariableReferences()) {
        // Match ids by value, not identity. Otherwise a state built from equal but distinct id
        // objects would be silently ignored. If the id type does not override equals, this
        // behaves exactly like the previous identity check.
        this.vertices.stream()
            .filter(v -> v.getId().equals(reference))
            .forEach(v -> v.setValue(state.get(reference)));
    }
}
/**
 * Builds one header column per element of every vertex's value. Each column is named with
 * {@code HEADER_STYLE}, applied to the vertex id and the element's index.
 *
 * @return this writer, for chaining
 */
@Override
public Writer withDefaultHeader() {
    List<String> columnNames = new ArrayList<>();
    for (Vertex<? extends Tensor> v : vertices) {
        long elementCount = v.getValue().getLength();
        for (int index = 0; index < elementCount; index++) {
            columnNames.add(String.format(HEADER_STYLE, v.getId(), index));
        }
    }
    withHeader(columnNames.toArray(new String[0]));
    return this;
}
}
/**
 * Builds one header column per vertex. Each column is produced by {@code createHeader} from
 * {@code HEADER_STYLE} and the vertex id's string form.
 *
 * @return this writer, for chaining
 */
@Override
public Writer withDefaultHeader() {
    withHeader(createHeader(vertices.size(), HEADER_STYLE, index -> vertices.get(index).getId().toString()));
    return this;
}
public String inDotFormat() { // Output value if value is set, but also add some descriptive info for non-constant vertices. if (!value.isEmpty()) { String dotLabel = vertex.getId().hashCode() + DOT_LABEL_OPENING + value; if (!(vertex instanceof ConstantVertex)) { dotLabel += " (" + getDescriptiveInfo() + ")"; } return dotLabel + DOT_LABEL_CLOSING; } return vertex.getId().hashCode() + DOT_LABEL_OPENING + getDescriptiveInfo() + DOT_LABEL_CLOSING; }
/**
 * Moves the network one nesting level deeper. The network's indentation is incremented, every
 * vertex whose label is not a key of {@code outputVertices} gets its id prefixed with a fresh
 * {@code VertexId}, and then every vertex whose label is a key has its id reset.
 *
 * @param bayesianNetwork the network whose vertices are re-identified
 * @param outputVertices  vertices keyed by label; these keys are the labels exempt from prefixing
 */
private static void increaseDepth(BayesianNetwork bayesianNetwork, Map<VertexLabel, Vertex> outputVertices) {
    VertexId newPrefix = new VertexId();
    bayesianNetwork.incrementIndentation();
    // Two separate passes keep the original order: all prefixes are added before any id is reset.
    for (Vertex candidate : bayesianNetwork.getVertices()) {
        if (!outputVertices.containsKey(candidate.getLabel())) {
            candidate.getId().addPrefix(newPrefix);
        }
    }
    for (Vertex candidate : bayesianNetwork.getVertices()) {
        if (outputVertices.containsKey(candidate.getLabel())) {
            candidate.getId().resetID();
        }
    }
}
private void expectRateToBeMissing(Vertex vertex) { try { double acceptanceRate = acceptanceRateTracker.getAcceptanceRate(vertex.getId()); throw new RuntimeException(String.format("Expected rate for %s to be missing but got %.2f", vertex, acceptanceRate)); } catch (IllegalStateException e) { // pass } } }
/**
 * Builds a named parameter that refers to a parent vertex by the string form of its id.
 *
 * @param paramName the parameter's name
 * @param parent    the vertex being referenced
 * @return the named parameter produced by the generic {@code getParam} overload
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Vertex parent) {
    return getParam(
        paramName,
        builder -> builder.setParentVertex(
            KeanuSavedBayesNet.VertexID.newBuilder()
                .setId(parent.getId().toString())));
}
/**
 * Writing two scalar tensors as columns with the default header should produce a header of
 * brace-wrapped vertex ids and a single data row.
 */
@Test
public void writeRowOfScalarsToCsvWithHeader() throws IOException {
    File file = WriteCsv.asColumns(scalarTensors).withDefaultHeader().toFile(File.createTempFile("test",".csv"));
    try {
        CsvReader reader = ReadCsv.fromFile(file).expectHeader(true);
        List<List<String>> lines = reader.readLines();
        VertexId firstId = scalarTensors.get(0).getId();
        VertexId secondId = scalarTensors.get(1).getId();

        assertTrue(lines.size() == 1);
        assertTrue(reader.getHeader().equals(Arrays.asList("{" + firstId + "}", "{" + secondId + "}")));
        assertTrue(lines.get(0).equals(Arrays.asList("0.5", "1.5")));
    } finally {
        // Clean up even when an assertion fails, so failed runs do not leak temp files.
        file.delete();
    }
}
/**
 * Writing two column tensors of different lengths with the default header should produce a
 * header of brace-wrapped vertex ids and five rows. Positions past the end of the shorter tensor
 * appear as "-".
 */
@Test
public void writeColumnOfTensorsToCsvWithHeader() throws IOException {
    File file = WriteCsv.asColumns(columnTensors).withDefaultHeader().toFile(File.createTempFile("test",".csv"));
    try {
        CsvReader reader = ReadCsv.fromFile(file).expectHeader(true);
        List<List<String>> lines = reader.readLines();
        VertexId firstId = columnTensors.get(0).getId();
        VertexId secondId = columnTensors.get(1).getId();

        assertTrue(lines.size() == 5);
        assertTrue(reader.getHeader().equals(Arrays.asList("{" + firstId + "}", "{" + secondId + "}")));
        assertTrue(lines.get(0).equals(Arrays.asList("1.0", "5.0")));
        assertTrue(lines.get(4).equals(Arrays.asList("5.0", "-")));
    } finally {
        // Clean up even when an assertion fails, so failed runs do not leak temp files.
        file.delete();
    }
}
/**
 * Tracks vertex1's acceptance rate across two proposals. The first is rejected and the second is
 * created without being rejected.
 */
@Test
public void youCanTrackTheAcceptanceRateForASingleVertex() {
    Proposal rejected = new Proposal();
    rejected.setProposal(vertex1, 1.0);
    notifier.notifyProposalCreated(rejected);
    notifier.notifyProposalRejected();
    // One proposal, rejected: the test expects a rate of 0.
    assertThat(acceptanceRateTracker.getAcceptanceRate(vertex1.getId()), equalTo(0.0));

    Proposal notRejected = new Proposal();
    notRejected.setProposal(vertex1, 2.0);
    notifier.notifyProposalCreated(notRejected);
    // The second proposal is never rejected, and the test expects the rate to rise to 0.5.
    assertThat(acceptanceRateTracker.getAcceptanceRate(vertex1.getId()), equalTo(0.5));
}
/**
 * Builds a named parameter holding an array of vertex references. Each reference is the string
 * form of the vertex's id, in array order.
 *
 * @param paramName the parameter's name
 * @param param     the vertices to reference
 * @return the named parameter produced by the generic {@code getParam} overload
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Vertex[] param) {
    KeanuSavedBayesNet.VertexArray.Builder ids = KeanuSavedBayesNet.VertexArray.newBuilder();
    for (int i = 0; i < param.length; i++) {
        String idText = param[i].getId().toString();
        ids.addValues(KeanuSavedBayesNet.VertexID.newBuilder().setId(idText));
    }
    return getParam(paramName, builder -> builder.setVertexArrayParam(ids.build()));
}
/**
 * Converts a vertex into a {@code KeanuSavedBayesNet.Vertex} message. The message holds the
 * label (only when set), the id's string form, the vertex's canonical class name, its shape,
 * and whatever {@code saveParams} adds.
 *
 * @param vertex the vertex to serialise
 * @return the built message
 */
private KeanuSavedBayesNet.Vertex buildVertex(Vertex vertex) {
    KeanuSavedBayesNet.Vertex.Builder vertexBuilder = KeanuSavedBayesNet.Vertex.newBuilder();
    if (vertex.getLabel() != null) {
        vertexBuilder.setLabel(vertex.getLabel().toString());
    }
    // Generated builder setters return the same builder, so chaining mutates vertexBuilder in place.
    vertexBuilder
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()))
        .setVertexType(vertex.getClass().getCanonicalName())
        .addAllShape(Longs.asList(vertex.getShape()));
    saveParams(vertexBuilder, vertex);
    return vertexBuilder.build();
}
/**
 * Requesting a rate for a vertex that no proposal has touched should throw
 * {@link IllegalStateException}.
 */
@Test
public void itThrowsIfYouAskForTheAcceptanceRateForAnUnrecognisedSetOfVertices() {
    // NOTE(review): "[1]" presumably reflects vertex1's id as set up by the fixture. Confirm this if ids change.
    expectedException.expectMessage("No proposals have been registered for [1]");
    expectedException.expect(IllegalStateException.class);
    acceptanceRateTracker.getAcceptanceRate(vertex1.getId());
}
/**
 * Packages a vertex's serialised value together with the vertex's id (as a string) and its
 * observed flag.
 *
 * @param vertex the vertex the value belongs to
 * @param value  the already-serialised value
 * @return the built {@code StoredValue} message
 */
private KeanuSavedBayesNet.StoredValue getStoredValue(Vertex vertex, KeanuSavedBayesNet.VertexValue value) {
    KeanuSavedBayesNet.VertexID vertexId = KeanuSavedBayesNet.VertexID.newBuilder()
        .setId(vertex.getId().toString())
        .build();
    KeanuSavedBayesNet.StoredValue.Builder stored = KeanuSavedBayesNet.StoredValue.newBuilder();
    stored.setId(vertexId);
    stored.setValue(value);
    stored.setIsObserved(vertex.isObserved());
    return stored.build();
}
}