/**
 * Creates a JSON saver for the supplied network.
 *
 * <p>The actual model construction is delegated to a {@link ProtobufSaver} built for {@code net}.
 *
 * @param net network that will be saved
 */
public JsonSaver(BayesianNetwork net) {
    this.protobufSaver = new ProtobufSaver(net);
}
/**
 * Saves the value of a boolean vertex by forwarding it to the wrapped protobuf saver's
 * {@code saveValue}. Forwarding to {@code save} instead would append the vertex definition
 * to the network rather than storing its value.
 *
 * @param vertex the boolean vertex whose value is saved
 */
@Override
public void saveValue(BooleanVertex vertex) {
    protobufSaver.saveValue(vertex);
}
}
/**
 * Encodes every parent parameter of {@code vertex} into {@code vertexBuilder}.
 *
 * @param vertexBuilder builder receiving one encoded parameter per parent
 * @param vertex        vertex whose parents are encoded
 */
private void saveParams(KeanuSavedBayesNet.Vertex.Builder vertexBuilder, Vertex vertex) {
    final Map<String, Method> methodsByParentName = getParentRetrievalMethodMap(vertex);
    final String[] orderedNames = methodsByParentName.keySet().toArray(new String[0]);
    // Sort so parameters are always emitted in alphabetical order, independent of the
    // map's iteration order.
    Arrays.sort(orderedNames);
    for (int i = 0; i < orderedNames.length; i++) {
        final String name = orderedNames[i];
        final Method retrievalMethod = methodsByParentName.get(name);
        vertexBuilder.addParameters(getEncodedParam(vertex, name, retrievalMethod));
    }
}
/**
 * Round-trips {@code net} through an in-memory protobuf encoding (values included).
 *
 * @param net network to serialise
 * @return the network as read back by {@link ProtobufLoader}
 * @throws IOException if saving or loading fails
 */
private BayesianNetwork saveLoad(final BayesianNetwork net) throws IOException {
    final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    new ProtobufSaver(net).save(buffer, true);
    // Fail fast if nothing was written before trying to parse it.
    assertThat(buffer.size(), greaterThan(0));
    final byte[] serialised = buffer.toByteArray();
    return new ProtobufLoader().loadNetwork(new ByteArrayInputStream(serialised));
}
/**
 * Wraps the current value of an integer vertex as a stored value.
 *
 * @param vertex vertex whose value is encoded
 * @return the stored value carrying the encoded integer tensor
 */
private KeanuSavedBayesNet.StoredValue getValue(IntegerVertex vertex) {
    final KeanuSavedBayesNet.IntegerTensor tensor = getTensor(vertex.getValue());
    final KeanuSavedBayesNet.VertexValue.Builder valueBuilder = KeanuSavedBayesNet.VertexValue.newBuilder();
    valueBuilder.setIntVal(tensor);
    return getStoredValue(vertex, valueBuilder.build());
}
/**
 * Encodes {@code param} as a named protobuf parameter, choosing the protobuf field from the
 * parameter's runtime type.
 *
 * <p>Primitive {@code long[]} / {@code int[]} arrays are encoded directly. Boxed
 * {@code Long[]} / {@code Integer[]} arrays are unboxed first; a {@code null} element causes a
 * {@link NullPointerException}.
 *
 * @param paramName name the parameter is stored under
 * @param param     parameter value to encode
 * @return the encoded named parameter
 * @throws IllegalArgumentException if the runtime type of {@code param} is not supported
 */
private KeanuSavedBayesNet.NamedParam getTypedParam(String paramName, Object param) {
    if (param instanceof Vertex) {
        return getParam(paramName, (Vertex) param);
    } else if (param instanceof DoubleTensor) {
        return getParam(paramName, builder -> builder.setDoubleTensorParam(getTensor((DoubleTensor) param)));
    } else if (param instanceof IntegerTensor) {
        return getParam(paramName, builder -> builder.setIntTensorParam(getTensor((IntegerTensor) param)));
    } else if (param instanceof BooleanTensor) {
        return getParam(paramName, builder -> builder.setBoolTensorParam(getTensor((BooleanTensor) param)));
    } else if (param instanceof Double) {
        return getParam(paramName, builder -> builder.setDoubleParam((double) param));
    } else if (param instanceof Integer) {
        return getParam(paramName, builder -> builder.setIntParam((int) param));
    } else if (param instanceof Long) {
        return getParam(paramName, builder -> builder.setLongParam((long) param));
    } else if (param instanceof String) {
        return getParam(paramName, builder -> builder.setStringParam((String) param));
    } else if (param instanceof Boolean) {
        return getParam(paramName, builder -> builder.setBoolParam((boolean) param));
    } else if (param instanceof long[]) {
        // The getParam overload takes a primitive long[]; test for that exact type before casting.
        return getParam(paramName, (long[]) param);
    } else if (param instanceof Long[]) {
        return getParam(paramName, unboxLongs((Long[]) param));
    } else if (param instanceof Vertex[]) {
        return getParam(paramName, (Vertex[]) param);
    } else if (param instanceof int[]) {
        // The getParam overload takes a primitive int[]; test for that exact type before casting.
        return getParam(paramName, (int[]) param);
    } else if (param instanceof Integer[]) {
        return getParam(paramName, unboxInts((Integer[]) param));
    } else {
        throw new IllegalArgumentException("Unknown Parameter Type to Save: " + param.getClass().toString());
    }
}

/** Copies a boxed {@code Long[]} into a new primitive array (NPE on a null element). */
private static long[] unboxLongs(Long[] boxed) {
    final long[] values = new long[boxed.length];
    for (int i = 0; i < boxed.length; i++) {
        values[i] = boxed[i];
    }
    return values;
}

/** Copies a boxed {@code Integer[]} into a new primitive array (NPE on a null element). */
private static int[] unboxInts(Integer[] boxed) {
    final int[] values = new int[boxed.length];
    for (int i = 0; i < boxed.length; i++) {
        values[i] = boxed[i];
    }
    return values;
}
/**
 * Encodes a primitive {@code long[]} as a named long-array parameter.
 *
 * @param paramName name the parameter is stored under
 * @param param     values to store, in order
 * @return the encoded named parameter
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, long[] param) {
    return getParam(paramName, namedParam -> namedParam.setLongArrayParam(
        KeanuSavedBayesNet.LongArray.newBuilder()
            .addAllValues(Longs.asList(param))
            .build()));
}
/**
 * Builds the protobuf model of the network and writes it to {@code output} in binary form.
 *
 * @param output     stream the serialised model is written to; this method does not close it
 * @param saveValues passed to {@code getModel} as {@code withSavedValues}
 * @param metadata   metadata entries passed to {@code getModel}
 * @throws IOException if writing to {@code output} fails
 */
@Override
public void save(OutputStream output, boolean saveValues, Map<String, String> metadata) throws IOException {
    try {
        KeanuSavedBayesNet.Model protobufModel = getModel(saveValues, metadata);
        protobufModel.writeTo(output);
    } finally {
        // Release the builder even when building or writing fails, so a partially built
        // model is not retained by this saver.
        modelBuilder = null;
    }
}
/**
 * Builds and returns the protobuf model for this saver's network.
 *
 * <p>{@code createProtobufModel} is expected to (re)populate {@code modelBuilder}. This method
 * does not clear {@code modelBuilder} afterwards.
 *
 * @param withSavedValues flag forwarded to {@code createProtobufModel}
 * @param metadata        metadata forwarded to {@code createProtobufModel}
 * @return the built model
 */
protected KeanuSavedBayesNet.Model getModel(boolean withSavedValues, Map<String, String> metadata) {
    createProtobufModel(withSavedValues, metadata);
    final KeanuSavedBayesNet.Model built = modelBuilder.build();
    return built;
}
/**
 * Appends the definition of {@code vertex} (as produced by {@code buildVertex}) to the network
 * section of the model being built.
 *
 * @param vertex vertex to serialise
 * @throws IllegalArgumentException if {@code vertex} is a {@link NonSaveableVertex}
 */
@Override
public void save(Vertex vertex) {
    final boolean isSaveable = !(vertex instanceof NonSaveableVertex);
    if (isSaveable) {
        modelBuilder.getNetworkBuilder().addVertices(buildVertex(vertex));
    } else {
        throw new IllegalArgumentException("Trying to save a vertex that isn't Saveable");
    }
}
/**
 * Wraps a pre-formatted value as a generic tensor that has the vertex's shape and holds the
 * single string {@code formattedValue}.
 *
 * @param vertex         vertex the value belongs to
 * @param formattedValue textual form of the value
 * @return the stored value
 */
private KeanuSavedBayesNet.StoredValue getValue(Vertex vertex, String formattedValue) {
    final KeanuSavedBayesNet.GenericTensor.Builder tensorBuilder = KeanuSavedBayesNet.GenericTensor.newBuilder();
    tensorBuilder.addAllShape(Longs.asList(vertex.getShape()));
    tensorBuilder.addValues(formattedValue);
    final KeanuSavedBayesNet.VertexValue wrapped = KeanuSavedBayesNet.VertexValue.newBuilder()
        .setGenericVal(tensorBuilder.build())
        .build();
    return getStoredValue(vertex, wrapped);
}
/**
 * Checks that metadata passed to {@link ProtobufSaver#save} is present in the parsed model.
 */
@Test
public void metadataCanBeSavedToProtobuf() throws IOException {
    Vertex vertex = new ConstantIntegerVertex(1);
    BayesianNetwork net = new BayesianNetwork(vertex.getConnectedGraph());
    Map<String, String> metadata = ImmutableMap.of("Author", "Some Author", "Tag", "MyBayesNet");

    ByteArrayOutputStream output = new ByteArrayOutputStream();
    ProtobufSaver protobufSaver = new ProtobufSaver(net);
    protobufSaver.save(output, true, metadata);

    KeanuSavedBayesNet.Model parsedModel = KeanuSavedBayesNet.Model.parseFrom(output.toByteArray());
    // Map equality ignores entry order, so the saved map can be compared directly with the input.
    // assertEquals takes (expected, actual).
    assertEquals(metadata, parsedModel.getMetadata().getMetadataInfoMap());
}
/**
 * Wraps the current value of a double vertex as a stored value.
 *
 * @param vertex vertex whose value is encoded
 * @return the stored value carrying the encoded double tensor
 */
private KeanuSavedBayesNet.StoredValue getValue(DoubleVertex vertex) {
    return getStoredValue(
        vertex,
        KeanuSavedBayesNet.VertexValue.newBuilder()
            .setDoubleVal(getTensor(vertex.getValue()))
            .build());
}
/**
 * Encodes a primitive {@code int[]} as a named int-array parameter.
 *
 * @param paramName name the parameter is stored under
 * @param param     values to store, in order
 * @return the encoded named parameter
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, int[] param) {
    return getParam(paramName, namedParam -> namedParam.setIntArrayParam(
        KeanuSavedBayesNet.IntArray.newBuilder()
            .addAllValues(Ints.asList(param))
            .build()));
}
/**
 * Builds the protobuf model of the network, prints it as JSON, and writes the text to
 * {@code output} encoded as UTF-8.
 *
 * <p>{@code output} is closed once the JSON has been written.
 *
 * @param output     destination stream (closed by this method after writing)
 * @param saveValues passed to the protobuf saver's {@code getModel}
 * @param metadata   metadata passed to the protobuf saver's {@code getModel}
 * @throws IOException if the model cannot be printed as JSON or the write fails
 */
@Override
public void save(OutputStream output, boolean saveValues, Map<String, String> metadata) throws IOException {
    KeanuSavedBayesNet.Model model = protobufSaver.getModel(saveValues, metadata);
    String jsonOutput = JsonFormat.printer().print(model);
    // JSON exchanged between systems must be UTF-8 (RFC 8259), so the platform default
    // charset must not be relied on. try-with-resources closes the writer even if write fails.
    try (Writer outputWriter = new OutputStreamWriter(output, java.nio.charset.StandardCharsets.UTF_8)) {
        outputWriter.write(jsonOutput);
    }
}
/**
 * Round-trips a labelled Gaussian whose mu is a concatenation of two constants, then checks
 * the loaded structure, the parameter values and the label lookup.
 */
@Test
public void youCanSaveAndLoadANetworkWithValues() throws IOException {
    final String gaussianLabel = "Gaussian";
    final DoubleVertex firstMu = new ConstantDoubleVertex(new double[]{3.0, 1.0});
    final DoubleVertex secondMu = new ConstantDoubleVertex(new double[]{5.0, 6.0});
    final DoubleVertex concatenatedMu = new ConcatenationVertex(0, firstMu, secondMu);
    final DoubleVertex gaussian = new GaussianVertex(concatenatedMu, 1.0);
    gaussian.setLabel(gaussianLabel);
    final BayesianNetwork net = new BayesianNetwork(gaussian.getConnectedGraph());

    final ByteArrayOutputStream output = new ByteArrayOutputStream();
    new ProtobufSaver(net).save(output, true);
    assertThat(output.size(), greaterThan(0));

    final ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
    final BayesianNetwork readNet = new ProtobufLoader().loadNetwork(input);

    assertThat(readNet.getLatentVertices().size(), is(1));
    assertThat(readNet.getLatentVertices().get(0), instanceOf(GaussianVertex.class));
    final GaussianVertex latentGaussian = (GaussianVertex) readNet.getLatentVertices().get(0);
    final GaussianVertex labelledGaussian =
        (GaussianVertex) readNet.getVertexByLabel(new VertexLabel(gaussianLabel));
    assertThat(latentGaussian, equalTo(labelledGaussian));

    // mu was built as [3, 1] ++ [5, 6]; check one element from each half.
    assertThat(latentGaussian.getMu().getValue(0), closeTo(3.0, 1e-10));
    assertThat(labelledGaussian.getMu().getValue(2), closeTo(5.0, 1e-10));
    assertThat(latentGaussian.getSigma().getValue().scalar(), closeTo(1.0, 1e-10));
    // The loaded vertex should be usable, not only structurally correct.
    latentGaussian.sample();
}
/**
 * Saves the value of {@code vertex} by forwarding it to the wrapped protobuf saver's
 * {@code saveValue}. Forwarding to {@code save} instead would append the vertex definition
 * to the network rather than storing its value.
 *
 * @param vertex vertex whose value is saved
 */
@Override
public void saveValue(Vertex vertex) {
    protobufSaver.saveValue(vertex);
}
/**
 * Writes {@code net} to {@code outputStream} using a {@link ProtobufSaver}.
 *
 * @param net                       network to save
 * @param outputStream              destination stream
 * @param saveValuesAndObservations boolean flag forwarded to {@code NetworkSaver.save};
 *                                  presumably selects whether values/observations are included
 * @throws IOException if writing fails
 */
public void saveNetToProtobuf(BayesianNetwork net, OutputStream outputStream, boolean saveValuesAndObservations) throws IOException {
    NetworkSaver saver = new ProtobufSaver(net);
    saver.save(outputStream, saveValuesAndObservations);
} //%%SNIPPET_END%% SaveToProtobuf
/**
 * Wraps the current value of a boolean vertex as a stored value.
 *
 * @param vertex vertex whose value is encoded
 * @return the stored value carrying the encoded boolean tensor
 */
private KeanuSavedBayesNet.StoredValue getValue(BooleanVertex vertex) {
    final KeanuSavedBayesNet.BooleanTensor encoded = getTensor(vertex.getValue());
    final KeanuSavedBayesNet.VertexValue.Builder valueBuilder = KeanuSavedBayesNet.VertexValue.newBuilder();
    return getStoredValue(vertex, valueBuilder.setBoolVal(encoded).build());
}
/**
 * Encodes a reference to a parent vertex, stored by the string form of its id.
 *
 * @param paramName name the parameter is stored under
 * @param parent    parent vertex being referenced
 * @return the encoded named parameter
 */
private KeanuSavedBayesNet.NamedParam getParam(String paramName, Vertex parent) {
    return getParam(paramName, namedParam -> namedParam.setParentVertex(
        KeanuSavedBayesNet.VertexID.newBuilder()
            .setId(parent.getId().toString())
            .build()));
}