/**
 * Writes a dummy model to a temporary file with {@link PMMLUtils#write}, reads it back with
 * {@link PMMLUtils#read}, and checks that the single tree model, its root node's record count
 * and its mining function survive the round trip.
 */
@Test
public void testReadWrite() throws Exception {
  Path tempModelFile = Files.createTempFile(getTempDir(), "model", ".pmml");
  PMML model = buildDummyModel();
  PMMLUtils.write(model, tempModelFile);
  assertTrue(Files.exists(tempModelFile));

  PMML model2 = PMMLUtils.read(tempModelFile);
  List<Model> models = model2.getModels();
  assertEquals(1, models.size());
  assertInstanceOf(models.get(0), TreeModel.class);
  TreeModel treeModel = (TreeModel) models.get(0);
  // Pass an explicit delta: JUnit 4's two-argument assertEquals(double, double) is deprecated
  // (and in JUnit 4.12 fails unconditionally). 123.0 is exactly representable, so 0.0 is safe.
  assertEquals(123.0, treeModel.getNode().getRecordCount().doubleValue(), 0.0);
  assertEquals(MiningFunction.CLASSIFICATION, treeModel.getMiningFunction());
}
/**
 * Checks the exact XML produced by {@link PMMLUtils#toString} for the dummy model. The header
 * timestamp is cleared first so that the serialized form does not vary between runs.
 */
@Test
public void testToString() throws Exception {
  PMML dummy = buildDummyModel();
  dummy.getHeader().setTimestamp(null);
  String expectedXml =
      "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>" +
      "<PMML version=\"4.3\" xmlns=\"http://www.dmg.org/PMML-4_3\">" +
      "<Header>" +
      "<Application name=\"Oryx\"/>" +
      "</Header>" +
      "<TreeModel functionName=\"classification\">" +
      "<Node recordCount=\"123.0\"/>" +
      "</TreeModel>" +
      "</PMML>";
  String actualXml = PMMLUtils.toString(dummy);
  assertEquals(expectedXml, actualXml);
}
@Override public PMML buildModel(JavaSparkContext sparkContext, JavaRDD<String> trainData, List<?> hyperParameters, Path candidatePath) { // If lists are unequal at this point, there must have been an empty test set // which yielded no call to evaluate(). Fill in the blank while (trainCounts.size() > testCounts.size()) { testCounts.add(0); } trainCounts.add((int) trainData.count()); return PMMLUtilsTest.buildDummyModel(); }
/**
 * Checks that {@link PMMLUtils#buildSkeletonPMML()} fills in a header naming the "Oryx"
 * application and carrying a timestamp.
 */
@Test
public void testSkeleton() {
  PMML skeleton = PMMLUtils.buildSkeletonPMML();
  Header header = skeleton.getHeader();
  assertEquals("Oryx", header.getApplication().getName());
  assertNotNull(header.getTimestamp());
}
@Test public void testPreviousPMMLVersion() throws Exception { String pmml42 = "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>\n" + "<PMML xmlns=\"http://www.dmg.org/PMML-4_2\" version=\"4.2.1\">\n" + " <Header>\n" + " <Application name=\"Oryx\"/>\n" + " </Header>\n" + " <TreeModel functionName=\"classification\">\n" + " <Node recordCount=\"123.0\"/>\n" + " </TreeModel>\n" + "</PMML>\n"; PMML model = PMMLUtils.fromString(pmml42); // Actually transforms to latest version: assertEquals(PMMLUtils.VERSION, model.getVersion()); }
/**
 * Serializes the dummy model with {@link PMMLUtils#toString}, parses it back with
 * {@link PMMLUtils#fromString}, and checks that the application name and the first model's
 * mining function are unchanged.
 */
@Test
public void testFromString() throws Exception {
  PMML original = buildDummyModel();
  String serialized = PMMLUtils.toString(original);
  PMML reparsed = PMMLUtils.fromString(serialized);

  String expectedApp = original.getHeader().getApplication().getName();
  String actualApp = reparsed.getHeader().getApplication().getName();
  assertEquals(expectedApp, actualApp);

  Model expectedModel = original.getModels().get(0);
  Model actualModel = reparsed.getModels().get(0);
  assertEquals(expectedModel.getMiningFunction(), actualModel.getMiningFunction());
}