/**
 * Converts a vertex into its protobuf form: the label (only when the vertex has one), the id,
 * the vertex's canonical class name, its shape, and finally its parameters via saveParams.
 *
 * NOTE(review): Class#getCanonicalName() returns null for anonymous/local classes, which the
 * protobuf string setter will reject - confirm such vertices are never saved.
 */
private KeanuSavedBayesNet.Vertex buildVertex(Vertex vertex) {
    KeanuSavedBayesNet.Vertex.Builder builder = KeanuSavedBayesNet.Vertex.newBuilder();

    // Labels are optional; skip the field entirely for unlabelled vertices.
    if (vertex.getLabel() != null) {
        builder.setLabel(vertex.getLabel().toString());
    }

    // Protobuf builder setters mutate and return the same builder, so chaining is safe here.
    builder
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId(vertex.getId().toString()))
        .setVertexType(vertex.getClass().getCanonicalName())
        .addAllShape(Longs.asList(vertex.getShape()));

    saveParams(builder, vertex);
    return builder.build();
}
/**
 * Loading a saved ConstantDoubleVertex that has no "constant" parameter must fail with an
 * IllegalArgumentException reporting the missing parent.
 */
@Test
public void loadFailsIfNoConstantSpecified() throws IOException {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Failed to create vertex due to missing parent: constant");

    // A constant vertex saved without any parameters.
    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .build();
    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();

    ProtobufLoader loader = new ProtobufLoader();
    // This call is expected to throw, so its result is deliberately not kept.
    loader.loadNetwork(new ByteArrayInputStream(savedModel.toByteArray()));
}
/**
 * Loading a saved vertex whose type name does not correspond to any known vertex class must fail
 * with an IllegalArgumentException naming the unknown type.
 */
@Test
public void loadFailsIfInvalidVertexSpecified() throws IOException {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Unknown Vertex Type Specified: made.up.vertex.NonExistentVertex");

    // A vertex whose type refers to a class that does not exist.
    KeanuSavedBayesNet.Vertex unknownVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType("made.up.vertex.NonExistentVertex")
        .build();
    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(unknownVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();

    ProtobufLoader loader = new ProtobufLoader();
    // This call is expected to throw, so its result is deliberately not kept.
    loader.loadNetwork(new ByteArrayInputStream(savedModel.toByteArray()));
}
.setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1")) .setLabel("muVertex") .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("2")) .setLabel("sigmaVertex") .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId(GAUSS_ID)) .setLabel(GAUSS_LABEL) .setVertexType(GaussianVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build()) .build() ).addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("sigma") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("2").build()) .build() .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId(idForValue)) .setValue(KeanuSavedBayesNet.VertexValue.newBuilder() .setDoubleVal(KeanuSavedBayesNet.DoubleTensor.newBuilder()
.setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1")) .setLabel("MU VERTEX") .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("2")) .setLabel("GAUSSIAN VERTEX") .setVertexType(GaussianVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build()) .build()
.setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1")) .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1")) .setValue(KeanuSavedBayesNet.VertexValue.newBuilder() .setIntVal(KeanuSavedBayesNet.IntegerTensor.newBuilder()
/**
 * Loading a ConstantDoubleVertex whose "constant" parameter is stored as an integer tensor, rather
 * than the expected double tensor, must fail with an IllegalArgumentException describing the
 * type mismatch.
 */
@Test
public void loadFailsIfWrongArgumentTypeSpecified() throws IOException {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Incorrect Parameter Type specified. " +
        "Got: class io.improbable.keanu.tensor.intgr.ScalarIntegerTensor, " +
        "Expected: interface io.improbable.keanu.tensor.dbl.DoubleTensor");

    // The "constant" parameter is deliberately supplied as an IntegerTensor.
    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder()
            .setName("constant")
            .setIntTensorParam(KeanuSavedBayesNet.IntegerTensor.newBuilder()
                .addAllShape(Longs.asList())
                .addValues(1)
                .build())
            .build())
        .build();
    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();

    ProtobufLoader loader = new ProtobufLoader();
    // This call is expected to throw, so its result is deliberately not kept.
    loader.loadNetwork(new ByteArrayInputStream(savedModel.toByteArray()));
}
/**
 * Packages a vertex value into a StoredValue entry keyed by the vertex's id and carrying the
 * vertex's observed flag.
 */
private KeanuSavedBayesNet.StoredValue getStoredValue(Vertex vertex, KeanuSavedBayesNet.VertexValue value) {
    KeanuSavedBayesNet.VertexID vertexId = KeanuSavedBayesNet.VertexID.newBuilder()
        .setId(vertex.getId().toString())
        .build();

    KeanuSavedBayesNet.StoredValue.Builder stored = KeanuSavedBayesNet.StoredValue.newBuilder();
    stored.setId(vertexId);
    stored.setValue(value);
    stored.setIsObserved(vertex.isObserved());
    return stored.build();
}
} // closes the enclosing class
/**
 * Encodes an array-of-vertices parameter as a NamedParam whose value is the list of the
 * vertices' ids, in array order.
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Vertex[] param) {
    KeanuSavedBayesNet.VertexArray.Builder idList = KeanuSavedBayesNet.VertexArray.newBuilder();
    for (int i = 0; i < param.length; i++) {
        String id = param[i].getId().toString();
        idList.addValues(KeanuSavedBayesNet.VertexID.newBuilder().setId(id));
    }
    KeanuSavedBayesNet.VertexArray vertexIds = idList.build();

    return getParam(paramName, paramBuilder -> paramBuilder.setVertexArrayParam(vertexIds));
}
/**
 * Encodes a single-vertex parameter as a NamedParam that references the parent vertex by its id.
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Vertex parent) {
    // The parent's id is read lazily, when the value-setting callback runs.
    return getParam(
        paramName,
        paramBuilder -> paramBuilder.setParentVertex(
            KeanuSavedBayesNet.VertexID.newBuilder().setId(parent.getId().toString())));
}