/**
 * Creates a {@code NamedParam} whose name is {@code paramName} and whose value is filled in by
 * the supplied callback.
 *
 * @param paramName   name stored on the resulting parameter
 * @param valueSetter callback that receives the builder (name already set) and populates the value
 * @return the built {@code NamedParam}
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Consumer<KeanuSavedBayesNet.NamedParam.Builder> valueSetter) {
    final KeanuSavedBayesNet.NamedParam.Builder builder =
        KeanuSavedBayesNet.NamedParam.newBuilder().setName(paramName);
    valueSetter.accept(builder);
    return builder.build();
}
.setLabel("muVertex") .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .setLabel("sigmaVertex") .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .setLabel(GAUSS_LABEL) .setVertexType(GaussianVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build()) .build() ).addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("sigma") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("2").build())
.setLabel("MU VERTEX") .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder() .setLabel("GAUSSIAN VERTEX") .setVertexType(GaussianVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("mu") .setParentVertex(KeanuSavedBayesNet.VertexID.newBuilder().setId("1").build())
.setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1")) .setVertexType(ConstantDoubleVertex.class.getCanonicalName()) .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder() .setName("constant") .setDoubleTensorParam(KeanuSavedBayesNet.DoubleTensor.newBuilder()
private Object getDecodedParam(KeanuSavedBayesNet.NamedParam parameter, Map<KeanuSavedBayesNet.VertexID, Vertex> existingVertices) { switch (parameter.getParamCase()) { case PARENTVERTEX: return existingVertices.get(parameter.getParentVertex()); return extractDoubleTensor(parameter.getDoubleTensorParam()); return extractIntTensor(parameter.getIntTensorParam()); return extractBoolTensor(parameter.getBoolTensorParam()); return parameter.getDoubleParam(); return parameter.getIntParam(); return parameter.getLongParam(); return parameter.getStringParam(); return parameter.getBoolParam(); return Longs.toArray(parameter.getLongArrayParam().getValuesList()); return Ints.toArray(parameter.getIntArrayParam().getValuesList()); + parameter.getParamCase().toString());
/**
 * Verifies that loading fails when a vertex parameter holds a value of the wrong tensor type.
 * An integer tensor is supplied as the {@code constant} parameter of a
 * {@link ConstantDoubleVertex}. The loader is expected to reject it with an
 * {@link IllegalArgumentException} whose message names both the received and the expected type.
 */
@Test
public void loadFailsIfWrongArgumentTypeSpecified() throws IOException {
    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Incorrect Parameter Type specified. "
        + "Got: class io.improbable.keanu.tensor.intgr.ScalarIntegerTensor, "
        + "Expected: interface io.improbable.keanu.tensor.dbl.DoubleTensor");

    // Integer tensor with an empty shape and a single value, passed where the vertex type
    // calls for a double tensor.
    KeanuSavedBayesNet.Vertex constantVertex = KeanuSavedBayesNet.Vertex.newBuilder()
        .setId(KeanuSavedBayesNet.VertexID.newBuilder().setId("1"))
        .setVertexType(ConstantDoubleVertex.class.getCanonicalName())
        .addParameters(KeanuSavedBayesNet.NamedParam.newBuilder()
            .setName("constant")
            .setIntTensorParam(KeanuSavedBayesNet.IntegerTensor.newBuilder()
                .addAllShape(Longs.asList())
                .addValues(1)
                .build())
            .build())
        .build();

    KeanuSavedBayesNet.BayesianNetwork savedNet = KeanuSavedBayesNet.BayesianNetwork.newBuilder()
        .addVertices(constantVertex)
        .build();
    KeanuSavedBayesNet.Model savedModel = KeanuSavedBayesNet.Model.newBuilder()
        .setNetwork(savedNet)
        .build();

    // Serialize to bytes so the loader is exercised through its InputStream entry point.
    ByteArrayOutputStream writer = new ByteArrayOutputStream();
    savedModel.writeTo(writer);

    ProtobufLoader loader = new ProtobufLoader();
    // No result is captured: the expectedException rule above requires this call to throw.
    loader.loadNetwork(new ByteArrayInputStream(writer.toByteArray()));
}