/**
 * Appends a new {@link Extension} that carries a name and a value, but no content, to a model.
 * Extensions already present on the model are not inspected, so an extension with the same
 * key may end up duplicated.
 *
 * @param pmml PMML model to add the extension to
 * @param key extension key, stored as the extension's name
 * @param value extension value; its {@code toString()} form is what gets stored
 */
public static void addExtension(PMML pmml, String key, Object value) {
  Extension extension = new Extension();
  extension.setName(key);
  extension.setValue(value.toString());
  pmml.addExtensions(extension);
}
/**
 * Builds a {@link PMML} object that has only its version and {@link Header} populated.
 * The header names "Oryx" as the {@link Application} and carries a {@link Timestamp} whose
 * content is the current date and time formatted as {@code yyyy-MM-dd'T'HH:mm:ssZZ}.
 *
 * @return {@link PMML} with common {@link Header} fields like {@link Application},
 *  {@link Timestamp}, and version filled out
 */
public static PMML buildSkeletonPMML() {
  SimpleDateFormat timestampFormat =
      new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ssZZ", Locale.ENGLISH);
  Timestamp timestamp = new Timestamp();
  timestamp.addContent(timestampFormat.format(new Date()));

  Header header = new Header();
  header.setTimestamp(timestamp);
  header.setApplication(new Application("Oryx"));

  // Third constructor argument is deliberately left null in the skeleton
  return new PMML(VERSION, header, null);
}
/**
 * Creates a new {@link PMML} object representing the PMML model defined in an XML file.
 *
 * @param file XML {@link File} holding the model definition; must not be {@code null}
 * @return the {@link PMML} object produced by {@code IOUtil.unmarshal}
 * @throws NullPointerException if {@code file} is {@code null}
 */
public static PMML newPmml(File file) throws JAXBException, SAXException, IOException {
  return IOUtil.unmarshal(Objects.requireNonNull(file));
}
dataFields.add(new DataField(FieldName.create("foo"), OpType.CONTINUOUS, DataType.DOUBLE)); dataFields.add(new DataField(FieldName.create("bar"), OpType.CONTINUOUS, DataType.DOUBLE)); DataDictionary dataDictionary = new DataDictionary(dataFields).setNumberOfFields(dataFields.size()); pmml.setDataDictionary(dataDictionary); MiningField predictorMF = new MiningField(FieldName.create("foo")) .setOpType(OpType.CONTINUOUS) .setUsageType(MiningField.UsageType.ACTIVE) .setImportance(0.5); miningFields.add(predictorMF); MiningField targetMF = new MiningField(FieldName.create("bar")) .setOpType(OpType.CONTINUOUS) .setUsageType(MiningField.UsageType.PREDICTED); miningFields.add(targetMF); MiningSchema miningSchema = new MiningSchema(miningFields); Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True()); Node left = new Node() .setId("r-") .setRecordCount(halfCount) .setPredicate(new True()) .setScore("-2.0"); Node right = new Node().setId("r+").setRecordCount(halfCount) .setPredicate(new SimplePredicate(FieldName.create("foo"), SimplePredicate.Operator.GREATER_THAN).setValue("3.14")) .setScore("2.0");
dataFields.add(new DataField(FieldName.create("x"), OpType.CONTINUOUS, DataType.DOUBLE)); dataFields.add(new DataField(FieldName.create("y"), OpType.CONTINUOUS, DataType.DOUBLE)); DataDictionary dataDictionary = new DataDictionary(dataFields).setNumberOfFields(dataFields.size()); pmml.setDataDictionary(dataDictionary); MiningField xMF = new MiningField(FieldName.create("x")) .setOpType(OpType.CONTINUOUS).setUsageType(MiningField.UsageType.ACTIVE); miningFields.add(xMF); MiningField yMF = new MiningField(FieldName.create("y")) .setOpType(OpType.CONTINUOUS).setUsageType(MiningField.UsageType.ACTIVE); miningFields.add(yMF); MiningSchema miningSchema = new MiningSchema(miningFields); clusteringFields.add(new ClusteringField( FieldName.create("x")).setCenterField(ClusteringField.CenterField.TRUE)); clusteringFields.add(new ClusteringField( FieldName.create("y")).setCenterField(ClusteringField.CenterField.TRUE)); clusters.add(new Cluster().setId("0").setSize(1).setArray(AppPMMLUtils.toArray(1.0, 0.0))); clusters.add(new Cluster().setId("1").setSize(2).setArray(AppPMMLUtils.toArray(2.0, -1.0))); clusters.add(new Cluster().setId("2").setSize(3).setArray(AppPMMLUtils.toArray(-1.0, 0.0))); pmml.addModels(new ClusteringModel( MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusters.size(), miningSchema, new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE).setMeasure(new SquaredEuclidean()),
/**
 * Checks the names, usage types and op types of the mining fields that
 * {@code AppPMMLUtils.buildMiningSchema} derives from the test schema.
 */
@Test
public void testBuildMiningSchema() {
  List<MiningField> fields =
      AppPMMLUtils.buildMiningSchema(buildTestSchema()).getMiningFields();
  assertEquals(4, fields.size());

  // Field names must come back in this exact order
  String[] expectedNames = { "foo", "bar", "baz", "bing" };
  int position = 0;
  for (String expectedName : expectedNames) {
    assertEquals(expectedName, fields.get(position).getName().getValue());
    position++;
  }

  MiningField foo = fields.get(0);
  MiningField bar = fields.get(1);
  MiningField baz = fields.get(2);
  MiningField bing = fields.get(3);

  assertEquals(MiningField.UsageType.SUPPLEMENTARY, foo.getUsageType());
  assertEquals(MiningField.UsageType.PREDICTED, bar.getUsageType());
  assertEquals(MiningField.UsageType.SUPPLEMENTARY, baz.getUsageType());
  assertEquals(MiningField.UsageType.ACTIVE, bing.getUsageType());

  assertEquals(OpType.CATEGORICAL, bar.getOpType());
  assertEquals(OpType.CONTINUOUS, bing.getOpType());
}
/**
 * Lists the name of every mining field in a schema.
 *
 * @param miningSchema {@link MiningSchema} from a model
 * @return names of features, in the order the schema's mining fields are listed
 */
public static List<String> getFeatureNames(MiningSchema miningSchema) {
  return miningSchema.getMiningFields()
      .stream()
      .map(miningField -> miningField.getName())
      .map(fieldName -> fieldName.getValue())
      .collect(Collectors.toList());
}
/**
 * Lists the name of every data field in a dictionary. A dictionary whose field list is
 * {@code null} or empty is rejected through {@code Preconditions.checkArgument}.
 *
 * @param dictionary {@link DataDictionary} from model
 * @return names of features, in the order the dictionary's data fields are listed
 */
public static List<String> getFeatureNames(DataDictionary dictionary) {
  List<DataField> fields = dictionary.getDataFields();
  boolean hasFields = !(fields == null || fields.isEmpty());
  Preconditions.checkArgument(hasFields, "No fields in DataDictionary");
  return fields.stream()
      .map(dataField -> dataField.getName())
      .map(fieldName -> fieldName.getValue())
      .collect(Collectors.toList());
}
/**
 * Asserts a data field's name and, depending on {@code categorical}, its op type and data type.
 *
 * @param field data field under test
 * @param name expected field name
 * @param categorical {@code null} to expect no op type and no data type; {@code true} to
 *  expect a categorical string field; {@code false} to expect a continuous double field
 */
private static void checkDataField(DataField field, String name, Boolean categorical) {
  assertEquals(name, field.getName().getValue());

  if (categorical == null) {
    assertNull(field.getOpType());
    assertNull(field.getDataType());
    return;
  }

  OpType expectedOpType = categorical ? OpType.CATEGORICAL : OpType.CONTINUOUS;
  DataType expectedDataType = categorical ? DataType.STRING : DataType.DOUBLE;
  assertEquals(expectedOpType, field.getOpType());
  assertEquals(expectedDataType, field.getDataType());
}
/**
 * Checks that {@code AppPMMLUtils.buildCategoricalValueEncodings} numbers the values of the
 * categorical field (index 1) in the order they were declared on the field.
 */
@Test
public void testBuildCategoricalEncoding() {
  DataField fooField =
      new DataField(FieldName.create("foo"), OpType.CONTINUOUS, DataType.DOUBLE);
  DataField barField =
      new DataField(FieldName.create("bar"), OpType.CATEGORICAL, DataType.STRING);
  // Declared as "b" then "a" on purpose: encoding follows declaration order, not sort order
  barField.addValues(new Value("b"), new Value("a"));

  List<DataField> dataFields = new ArrayList<>();
  dataFields.add(fooField);
  dataFields.add(barField);
  DataDictionary dictionary = new DataDictionary(dataFields);
  dictionary.setNumberOfFields(dataFields.size());

  CategoricalValueEncodings encodings = AppPMMLUtils.buildCategoricalValueEncodings(dictionary);

  final int barIndex = 1;
  assertEquals(2, encodings.getValueCount(barIndex));
  assertEquals(0, encodings.getValueEncodingMap(barIndex).get("b").intValue());
  assertEquals(1, encodings.getValueEncodingMap(barIndex).get("a").intValue());
  assertEquals("b", encodings.getEncodingValueMap(barIndex).get(0));
  assertEquals("a", encodings.getEncodingValueMap(barIndex).get(1));
  assertEquals(Collections.singletonMap(1, 2), encodings.getCategoryCounts());
}
/**
 * Serializes the dummy model to a string, parses it back, and checks that the application
 * name and the first model's mining function survive the round trip.
 */
@Test
public void testFromString() throws Exception {
  PMML original = buildDummyModel();
  String serialized = PMMLUtils.toString(original);
  PMML roundTripped = PMMLUtils.fromString(serialized);

  assertEquals(original.getHeader().getApplication().getName(),
               roundTripped.getHeader().getApplication().getName());
  assertEquals(original.getModels().get(0).getMiningFunction(),
               roundTripped.getModels().get(0).getMiningFunction());
}
/**
 * Finds the position of the first mining field whose usage type is
 * {@link MiningField.UsageType#PREDICTED}.
 *
 * @param miningSchema {@link MiningSchema} from a model
 * @return index of the {@link MiningField.UsageType#PREDICTED} feature, or {@code null}
 *  if no mining field has that usage type
 */
public static Integer findTargetIndex(MiningSchema miningSchema) {
  List<MiningField> fields = miningSchema.getMiningFields();
  int index = 0;
  for (MiningField field : fields) {
    if (MiningField.UsageType.PREDICTED == field.getUsageType()) {
      return index;
    }
    index++;
  }
  return null;
}
/**
 * Reads from the tuple, for each entry of {@code activeFields}, the value stored under that
 * field's name. Entries are kept in the iteration order of {@code activeFields}.
 *
 * @param tuple tuple to read the values from
 * @return The raw inputs extracted from the tuple for all 'active fields'
 */
@Override
public Map<FieldName, Object> extractRawInputs(Tuple tuple) {
  LOG.debug("Extracting raw inputs from tuple: = [{}]", tuple);

  final Map<FieldName, Object> inputsByField = new LinkedHashMap<>();
  for (FieldName field : activeFields) {
    final String tupleFieldName = field.getValue();
    final Object rawValue = tuple.getValueByField(tupleFieldName);
    inputsByField.put(field, rawValue);
  }

  LOG.debug("Raw inputs = [{}]", inputsByField);
  return inputsByField;
}
/**
 * Appends a new {@link Extension} that carries a name and a single {@code String} content,
 * but no value, to a model. The content values are joined with
 * {@code TextUtils.joinPMMLDelimited}, i.e. encoded as if they were being added to a PMML
 * {@link Array}: space-separated, with PMML quoting rules. Nothing is added when
 * {@code content} is empty.
 *
 * @param pmml PMML model to add the extension to
 * @param key extension key, stored as the extension's name
 * @param content values to join and add as a single {@code String}
 */
public static void addExtensionContent(PMML pmml, String key, Collection<?> content) {
  if (!content.isEmpty()) {
    Extension extension = new Extension();
    extension.setName(key);
    extension.addContent(TextUtils.joinPMMLDelimited(content));
    pmml.addExtensions(extension);
  }
}
/**
 * Checks that the skeleton model's header names "Oryx" as its application and has a timestamp.
 */
@Test
public void testSkeleton() {
  PMML skeleton = PMMLUtils.buildSkeletonPMML();
  String applicationName = skeleton.getHeader().getApplication().getName();
  assertEquals("Oryx", applicationName);
  assertNotNull(skeleton.getHeader().getTimestamp());
}
/**
 * Asserts that a header is present, has a timestamp, and names "Oryx" as its application.
 *
 * @param header header to check
 */
protected static void checkHeader(Header header) {
  assertNotNull(header);
  assertNotNull(header.getTimestamp());
  String applicationName = header.getApplication().getName();
  assertEquals("Oryx", applicationName);
}
/**
 * Looks up the value of the first extension on a model whose name equals the given name.
 *
 * @param pmml PMML model whose extensions are searched, in order
 * @param name extension name to look for
 * @return value of the first matching extension; {@code null} if no extension matches or
 *  the first match has no value
 */
public static String getExtensionValue(PMML pmml, String name) {
  for (Extension extension : pmml.getExtensions()) {
    if (name.equals(extension.getName())) {
      // Only the first match counts, even when its value is null
      return extension.getValue();
    }
  }
  return null;
}
/**
 * Converts numbers into a PMML {@link Array} of type {@code REAL}. The array's content is
 * the values joined by {@code TextUtils.joinPMMLDelimitedNumbers}, and its {@code n} is set
 * to the number of values.
 *
 * @param values {@code double} values to make into a PMML {@link Array}
 * @return PMML {@link Array} representation
 */
public static Array toArray(double... values) {
  int count = values.length;
  List<Double> boxedValues = new ArrayList<>(count);
  for (int i = 0; i < count; i++) {
    boxedValues.add(values[i]);
  }
  Array array = new Array(Array.Type.REAL, TextUtils.joinPMMLDelimitedNumbers(boxedValues));
  return array.setN(count);
}
/**
 * Checks the exact XML that {@code PMMLUtils.toString} produces for the dummy model.
 */
@Test
public void testToString() throws Exception {
  PMML model = buildDummyModel();
  // Clear the timestamp: the expected XML below contains no Timestamp element
  model.getHeader().setTimestamp(null);

  String expectedXml =
      "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>" +
      "<PMML version=\"4.3\" xmlns=\"http://www.dmg.org/PMML-4_3\">" +
      "<Header>" +
      "<Application name=\"Oryx\"/>" +
      "</Header>" +
      "<TreeModel functionName=\"classification\">" +
      "<Node recordCount=\"123.0\"/>" +
      "</TreeModel>" +
      "</PMML>";

  assertEquals(expectedXml, PMMLUtils.toString(model));
}
/**
 * Creates a new {@link PMML} object representing the PMML model read from an input stream.
 *
 * @param stream {@link InputStream} supplying the model definition; must not be {@code null}
 * @return the {@link PMML} object produced by {@code IOUtil.unmarshal}
 * @throws NullPointerException if {@code stream} is {@code null}
 */
public static PMML newPmml(InputStream stream) throws JAXBException, SAXException, IOException {
  return IOUtil.unmarshal(Objects.requireNonNull(stream));
}