/**
 * Checks a PMML-encoded clustering model against the schema the application expects.
 * The document must hold exactly one model, that model must be a {@link ClusteringModel}
 * whose mining function is {@code CLUSTERING}, and the feature names derived from both the
 * data dictionary and the model's mining schema must equal the schema's feature names.
 * Each condition is enforced with {@code Preconditions.checkArgument}.
 *
 * @param pmml {@link PMML} encoding of KMeans Clustering
 * @param schema expected schema attributes of KMeans Clustering
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> allModels = pmml.getModels();
  Preconditions.checkArgument(
      allModels.size() == 1,
      "Should have exactly one model, but had %s",
      allModels.size());

  Model onlyModel = allModels.get(0);
  boolean isClusteringModel = onlyModel instanceof ClusteringModel;
  Preconditions.checkArgument(isClusteringModel);
  boolean isClusteringFunction = onlyModel.getMiningFunction() == MiningFunction.CLUSTERING;
  Preconditions.checkArgument(isClusteringFunction);

  // Feature names must agree with the document-level data dictionary ...
  DataDictionary dataDictionary = pmml.getDataDictionary();
  boolean dictionaryMatches =
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dataDictionary));
  Preconditions.checkArgument(
      dictionaryMatches,
      "Feature names in schema don't match names in PMML");

  // ... and with the model's own mining schema.
  MiningSchema modelSchema = onlyModel.getMiningSchema();
  boolean miningSchemaMatches =
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(modelSchema));
  Preconditions.checkArgument(miningSchemaMatches);
}
/**
 * Converts the clusters of the first model in a PMML document into {@link ClusterInfo}
 * objects. The first model must be a {@link ClusteringModel}. For each cluster, the id is
 * parsed as an integer and the space-delimited array value is parsed into a vector.
 *
 * @param pmml PMML representation of Clusters
 * @return List of {@link ClusterInfo}
 */
public static List<ClusterInfo> read(PMML pmml) {
  Model firstModel = pmml.getModels().get(0);
  Preconditions.checkArgument(firstModel instanceof ClusteringModel);
  ClusteringModel kMeansModel = (ClusteringModel) firstModel;

  return kMeansModel.getClusters()
      .stream()
      .map(cluster -> {
        int clusterId = Integer.parseInt(cluster.getId());
        return new ClusterInfo(
            clusterId,
            VectorMath.parseVector(TextUtils.parseDelimited(cluster.getArray().getValue(), ' ')),
            cluster.getSize());
      })
      .collect(Collectors.toList());
}
// Require that the PMML document holds exactly one model; the failure message reports the actual count.
List<Model> models = pmml.getModels(); Preconditions.checkArgument(models.size() == 1, "Should have exactly one model, but had %s", models.size());
// NOTE(review): the result of buildCategoricalValueEncodings is not captured on this line; it looks like the tail of a statement that begins outside this excerpt. The statements after it take the first model of the document and read its MiningSchema.
AppPMMLUtils.buildCategoricalValueEncodings(dictionary); List<Model> models = pmml.getModels(); Model model = models.get(0); MiningSchema miningSchema = model.getMiningSchema();
// Take the first model in the document and cast it to ClusteringModel; no instanceof check precedes the cast on this line, so a different model type fails with ClassCastException.
Model rootModel = pmml.getModels().get(0); ClusteringModel clusteringModel = (ClusteringModel) rootModel;
/**
 * Writes the dummy model to a temporary ".pmml" file, reads it back, and checks that the
 * reloaded document holds a single {@link TreeModel} with the expected root record count
 * and mining function.
 */
@Test
public void testReadWrite() throws Exception {
  Path modelPath = Files.createTempFile(getTempDir(), "model", ".pmml");
  PMML written = buildDummyModel();
  PMMLUtils.write(written, modelPath);
  assertTrue(Files.exists(modelPath));

  PMML reloaded = PMMLUtils.read(modelPath);
  List<Model> reloadedModels = reloaded.getModels();
  assertEquals(1, reloadedModels.size());

  Model firstModel = reloadedModels.get(0);
  assertInstanceOf(firstModel, TreeModel.class);
  TreeModel tree = (TreeModel) firstModel;
  assertEquals(123.0, tree.getNode().getRecordCount().doubleValue());
  assertEquals(MiningFunction.CLASSIFICATION, tree.getMiningFunction());
}
/**
 * Round-trips the dummy model through {@code PMMLUtils.toString} and
 * {@code PMMLUtils.fromString}, then compares the header application name and the first
 * model's mining function between the original and the restored document.
 */
@Test
public void testFromString() throws Exception {
  PMML original = buildDummyModel();
  PMML restored = PMMLUtils.fromString(PMMLUtils.toString(original));

  assertEquals(
      original.getHeader().getApplication().getName(),
      restored.getHeader().getApplication().getName());
  assertEquals(
      original.getModels().get(0).getMiningFunction(),
      restored.getModels().get(0).getMiningFunction());
}
// The first model in the document is asserted to be a ClusteringModel.
Model model = pmml.getModels().get(0); assertInstanceOf(model, ClusteringModel.class);
// Fetch the first model listed in the PMML document.
Model rootModel = pmml.getModels().get(0);
// Fetch the first model; when it is a TreeModel, NUM_TREES is asserted to equal 1 (the remainder of this branch lies outside the excerpt).
Model rootModel = pmml.getModels().get(0); if (rootModel instanceof TreeModel) { assertEquals(NUM_TREES, 1);
/**
 * Returns the first model in the PMML document whose model name equals {@code name}.
 *
 * @param pmml the document whose models are searched, in list order
 * @param name the model name to look for
 * @return the first model whose {@code getModelName()} equals {@code name}
 * @throws RuntimeException if no model carries the requested name
 */
public static Model getModelByName(PMML pmml, String name) {
    for (Model model : pmml.getModels()) {
        // Compare from the requested name's side. The PMML modelName attribute is optional,
        // so calling equals on model.getModelName() could hit a null and abort the search
        // with a NullPointerException before later models are examined.
        if (name != null && name.equals(model.getModelName())) {
            return model;
        }
    }
    throw new RuntimeException("No such model: " + name);
}
/**
 * Looks up a model by name among the models of a PMML document.
 *
 * @param pmml the document whose models are scanned in list order
 * @param name the model name to match
 * @return the first model whose {@code getModelName()} equals {@code name}
 * @throws RuntimeException if none of the models has the requested name
 */
public static Model getModelByName(PMML pmml, String name) {
    for (Model model : pmml.getModels()) {
        // equals is invoked on the requested name, not on model.getModelName(): the PMML
        // modelName attribute is optional, and a null model name must be skipped rather
        // than end the scan with a NullPointerException.
        if (name != null && name.equals(model.getModelName())) {
            return model;
        }
    }
    throw new RuntimeException("No such model: " + name);
}
/**
 * Passes every top-level {@link MiningModel} of the document to the
 * {@code transform(MiningModel)} overload; models of other types are skipped.
 *
 * @param pmml the document to process
 * @return the same {@code pmml} instance that was passed in
 */
@Override
public PMML transform(PMML pmml){

    for(Model candidate : pmml.getModels()){

        if(!(candidate instanceof MiningModel)){
            continue;
        }

        transform((MiningModel)candidate);
    }

    return pmml;
}
/**
 * Loads the PMML document from the configured model location.
 * The stream is passed through {@code ImportFilter.apply} before being unmarshalled, and the
 * resulting document is stored in the {@code pmml} field. The stream is closed when the
 * try-with-resources block exits. Finally the document is asserted to contain at least one model.
 */
@PostConstruct
public void setUp() throws IOException, SAXException, JAXBException {
    try (InputStream modelStream = properties.getModelLocation().getInputStream()) {
        InputSource rawInput = new InputSource(modelStream);
        Source filteredSource = ImportFilter.apply(rawInput);

        pmml = JAXBUtil.unmarshalPMML(filteredSource);

        Assert.state(
                !pmml.getModels().isEmpty(),
                "The provided PMML file at " + properties.getModelLocation() + " does not contain any model");
    }
}
/**
 * Reads the PMML document found at the configured model location into the {@code pmml} field.
 * The input is run through {@code ImportFilter.apply} and then unmarshalled; afterwards the
 * document is asserted to hold at least one model. The input stream is managed by
 * try-with-resources.
 */
@PostConstruct
public void setUp() throws IOException, SAXException, JAXBException {
    try (InputStream modelStream = properties.getModelLocation().getInputStream()) {
        InputSource rawInput = new InputSource(modelStream);
        Source filteredSource = ImportFilter.apply(rawInput);

        pmml = JAXBUtil.unmarshalPMML(filteredSource);

        Assert.state(
                !pmml.getModels().isEmpty(),
                "The provided PMML file at " + properties.getModelLocation() + " does not contain any model");
    }
}
/**
 * Finds a model of the PMML document that satisfies the given predicate.
 *
 * @param pmml the document whose models are searched
 * @param predicate the condition a model must satisfy
 * @param predicateXPath XPath-like fragment appended to the formatted PMML element name in the error message
 * @return a model accepted by {@code predicate}
 * @throws MissingElementException if the document has no models, or none of them satisfies {@code predicate}
 */
static
public Model findModel(PMML pmml, Predicate<Model> predicate, String predicateXPath){

    if(pmml.hasModels()){
        Optional<Model> match = pmml.getModels().stream()
            .filter(predicate)
            .findAny();

        if(match.isPresent()){
            return match.get();
        }
    }

    // Reached both when there are no models at all and when no model matched.
    throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(pmml.getClass()) + "/" + predicateXPath), pmml);
}
}
/**
 * Runs {@link DataDictionaryCleaner} on the chained-segmentation fixture and checks the
 * data fields before cleaning, after cleaning, and after cleaning again with all models removed.
 */
@Test
public void cleanChained() throws Exception {
    PMML document = ResourceUtil.unmarshal(ChainedSegmentationTest.class);
    DataDictionary dictionary = document.getDataDictionary();

    checkFields(FieldNameUtil.create("y", "x1", "x2", "x3", "x4"), dictionary.getDataFields());

    DataDictionaryCleaner dictionaryCleaner = new DataDictionaryCleaner();
    dictionaryCleaner.applyTo(document);

    checkFields(FieldNameUtil.create("y", "x1", "x2", "x3"), dictionary.getDataFields());

    // With every model removed, a second pass is expected to leave no data fields.
    document.getModels().clear();
    dictionaryCleaner.applyTo(document);

    checkFields(Collections.emptySet(), dictionary.getDataFields());
}
/**
 * Runs {@link DataDictionaryCleaner} on the nested-segmentation fixture and checks the
 * data fields before cleaning, after cleaning, and after cleaning again with all models removed.
 */
@Test
public void cleanNested() throws Exception {
    PMML document = ResourceUtil.unmarshal(NestedSegmentationTest.class);
    DataDictionary dictionary = document.getDataDictionary();

    checkFields(FieldNameUtil.create("y", "x1", "x2", "x3", "x4", "x5"), dictionary.getDataFields());

    DataDictionaryCleaner dictionaryCleaner = new DataDictionaryCleaner();
    dictionaryCleaner.applyTo(document);

    checkFields(FieldNameUtil.create("x1", "x2", "x3", "x4", "x5"), dictionary.getDataFields());

    // With every model removed, a second pass is expected to leave no data fields.
    document.getModels().clear();
    dictionaryCleaner.applyTo(document);

    checkFields(Collections.emptySet(), dictionary.getDataFields());
}
/**
 * Unmarshals the chained-segmentation fixture with a {@link SkipFilter} for the
 * "Segmentation" element and checks that the first model keeps its mining schema and
 * output but has no segmentation.
 */
@Test
public void filterChainedSegmentation() throws Exception {
    SkipFilter segmentationFilter = new SkipFilter("Segmentation");
    PMML document = ResourceUtil.unmarshal(ChainedSegmentationTest.class, segmentationFilter);

    assertNotNull(document.getDataDictionary());
    assertNotNull(document.getTransformationDictionary());

    MiningModel firstModel = (MiningModel)document.getModels().get(0);

    assertNotNull(firstModel.getMiningSchema());
    assertNotNull(firstModel.getOutput());

    assertNull(firstModel.getSegmentation());
}
/**
 * Unmarshals the nested-segmentation fixture with a {@link SkipFilter} for the
 * {@link Segmentation} class and checks that the first model keeps its mining schema,
 * local transformations and output but has no segmentation.
 */
@Test
public void filterNestedSegmentation() throws Exception {
    SkipFilter segmentationFilter = new SkipFilter(Segmentation.class);
    PMML document = ResourceUtil.unmarshal(NestedSegmentationTest.class, segmentationFilter);

    assertNotNull(document.getDataDictionary());

    MiningModel firstModel = (MiningModel)document.getModels().get(0);

    assertNotNull(firstModel.getMiningSchema());
    assertNotNull(firstModel.getLocalTransformations());
    assertNotNull(firstModel.getOutput());

    assertNull(firstModel.getSegmentation());
}