/**
 * Builds a center-based PMML {@link ClusteringModel} from a trained k-means model.
 *
 * <p>One {@link ClusteringField} (flagged as a center field) is emitted for every feature
 * that the input schema reports as active, and one {@link Cluster} is emitted per cluster
 * center, identified by its index in {@code model.clusterCenters()}.
 *
 * @param model trained k-means model supplying the cluster centers
 * @param clusterSizesMap cluster index to number of assigned points; it is dereferenced
 *  for every cluster index, so an index with no entry fails with a NullPointerException
 * @return clustering model using squared Euclidean distance as its comparison measure
 */
private ClusteringModel pmmlClusteringModel(KMeansModel model, Map<Integer,Long> clusterSizesMap) {
    Vector[] centers = model.clusterCenters();

    // Only features marked active in the schema take part in clustering.
    List<ClusteringField> fields = new ArrayList<>();
    int numFeatures = inputSchema.getNumFeatures();
    for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) {
        if (!inputSchema.isActive(featureIndex)) {
            continue;
        }
        String featureName = inputSchema.getFeatureNames().get(featureIndex);
        ClusteringField field = new ClusteringField(FieldName.create(featureName));
        field.setCenterField(ClusteringField.CenterField.TRUE);
        fields.add(field);
    }

    List<Cluster> pmmlClusters = new ArrayList<>(centers.length);
    for (int clusterIndex = 0; clusterIndex < centers.length; clusterIndex++) {
        Cluster pmmlCluster = new Cluster();
        pmmlCluster.setId(Integer.toString(clusterIndex));
        pmmlCluster.setSize(clusterSizesMap.get(clusterIndex).intValue());
        pmmlCluster.setArray(AppPMMLUtils.toArray(centers[clusterIndex].toArray()));
        pmmlClusters.add(pmmlCluster);
    }

    return new ClusteringModel(
        MiningFunction.CLUSTERING,
        ClusteringModel.ModelClass.CENTER_BASED,
        pmmlClusters.size(),
        AppPMMLUtils.buildMiningSchema(inputSchema),
        new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE).setMeasure(new SquaredEuclidean()),
        fields,
        pmmlClusters);
}
assertEquals(NUM_CLUSTERS, clusteringModel.getNumberOfClusters()); assertEquals(NUM_CLUSTERS, clusteringModel.getClusters().size()); assertEquals(NUM_FEATURES, clusteringModel.getClusteringFields().size()); assertEquals(ComparisonMeasure.Kind.DISTANCE, clusteringModel.getComparisonMeasure().getKind()); assertEquals(NUM_FEATURES, clusteringModel.getClusters().get(0).getArray().getN().intValue()); for (Cluster cluster : clusteringModel.getClusters()) { assertGreater(cluster.getSize(), 0);
// The model must declare exactly NUM_CLUSTERS clusters; the cluster list is then
// fetched into a local for the statements that follow.
assertEquals(NUM_CLUSTERS, clusteringModel.getNumberOfClusters()); List<Cluster> clusters = clusteringModel.getClusters();
/**
 * Offers this model, and then its child elements, to the given visitor.
 *
 * <p>Children are traversed in a fixed order (extensions, schema/output/statistics group,
 * clustering fields, center fields and missing value weights, clusters, model verification),
 * and each step only runs while the visitor keeps answering {@link VisitorAction#CONTINUE}.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the traversal was terminated,
 *  otherwise {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction outcome = visitor.visit(this);

    if (outcome == VisitorAction.CONTINUE) {
        visitor.pushParent(this);

        // The outcome is known to be CONTINUE here, so only the presence check is needed.
        if (hasExtensions()) {
            outcome = PMMLObject.traverse(visitor, getExtensions());
        }

        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getLocalTransformations(), getComparisonMeasure());
        }

        if (outcome == VisitorAction.CONTINUE && hasClusteringFields()) {
            outcome = PMMLObject.traverse(visitor, getClusteringFields());
        }

        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getCenterFields(), getMissingValueWeights());
        }

        if (outcome == VisitorAction.CONTINUE && hasClusters()) {
            outcome = PMMLObject.traverse(visitor, getClusters());
        }

        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getModelVerification());
        }

        visitor.popParent();
    }

    // Any outcome other than TERMINATE is reported to the caller as CONTINUE.
    return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Reads cluster descriptions out of the first model of a PMML document.
 *
 * <p>The first model must be a {@link ClusteringModel}. Each of its clusters is converted
 * using the cluster ID parsed as an integer, the center parsed from the cluster's
 * space-delimited array value, and the cluster's size.
 *
 * @param pmml PMML representation of clusters
 * @return list of {@link ClusterInfo}, one per cluster in the model
 */
public static List<ClusterInfo> read(PMML pmml) {
    Model firstModel = pmml.getModels().get(0);
    Preconditions.checkArgument(firstModel instanceof ClusteringModel);

    return ((ClusteringModel) firstModel).getClusters()
        .stream()
        .map(c -> {
            int clusterId = Integer.parseInt(c.getId());
            return new ClusterInfo(
                clusterId,
                VectorMath.parseVector(TextUtils.parseDelimited(c.getArray().getValue(), ' ')),
                c.getSize());
        })
        .collect(Collectors.toList());
}
/**
 * Encodes this estimator as a center-based PMML {@link ClusteringModel}.
 *
 * <p>The cluster centers shape is read as {@code [numberOfClusters, numberOfFeatures]}.
 * Each cluster gets its index as ID and the matching row of the centers matrix as its array.
 * A cluster size is only set when label counts are available; otherwise it is left null.
 *
 * @param schema schema supplying the label and the features of the model
 * @return clustering model with a squared Euclidean distance measure and a "Cluster" output
 */
@Override
public ClusteringModel encodeModel(Schema schema){
    int[] shape = getClusterCentersShape();
    int clusterCount = shape[0];
    int featureCount = shape[1];

    List<? extends Number> centers = getClusterCenters();
    List<Integer> labels = getLabels();

    // Number of training labels assigned to each cluster index; empty when labels are absent.
    Multiset<Integer> countsByLabel = HashMultiset.create();
    if(labels != null){
        countsByLabel.addAll(labels);
    }

    boolean hasCounts = (countsByLabel.size() > 0);

    List<Cluster> clusters = new ArrayList<>();
    for(int row = 0; row < clusterCount; row++){
        Integer size = null;
        if(hasCounts){
            size = countsByLabel.count(row);
        }

        Cluster cluster = new Cluster();
        cluster.setId(String.valueOf(row));
        cluster.setSize(size);
        cluster.setArray(PMMLUtil.createRealArray(CMatrixUtil.getRow(centers, clusterCount, featureCount, row)));

        clusters.add(cluster);
    }

    ComparisonMeasure measure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE);
    measure.setCompareFunction(CompareFunction.ABS_DIFF);
    measure.setMeasure(new SquaredEuclidean());

    return new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusterCount, ModelUtil.createMiningSchema(schema.getLabel()), measure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters)
        .setOutput(ClusteringModelUtil.createOutput(FieldName.create("Cluster"), DataType.DOUBLE, clusters));
}
/**
 * Creates an evaluator for the given clustering model, rejecting unsupported models up front.
 *
 * <p>The checks run in this order: a comparison measure must be present; the model class
 * must be {@code CENTER_BASED}; a CenterFields element must be absent; clustering fields
 * and clusters must be present; and a Targets element must be absent.
 *
 * @param pmml the enclosing PMML document
 * @param clusteringModel the clustering model to evaluate
 */
public ClusteringModelEvaluator(PMML pmml, ClusteringModel clusteringModel){
    super(pmml, clusteringModel);

    if(clusteringModel.getComparisonMeasure() == null){
        throw new MissingElementException(clusteringModel, PMMLElements.CLUSTERINGMODEL_COMPARISONMEASURE);
    }

    // Only center-based models are supported.
    ClusteringModel.ModelClass declaredClass = clusteringModel.getModelClass();
    switch(declaredClass){
        case CENTER_BASED:
            break;
        default:
            throw new UnsupportedAttributeException(clusteringModel, declaredClass);
    }

    CenterFields unsupportedCenterFields = clusteringModel.getCenterFields();
    if(unsupportedCenterFields != null){
        throw new UnsupportedElementException(unsupportedCenterFields);
    }

    boolean fieldsPresent = clusteringModel.hasClusteringFields();
    if(!fieldsPresent){
        throw new MissingElementException(clusteringModel, PMMLElements.CLUSTERINGMODEL_CLUSTERINGFIELDS);
    }

    boolean clustersPresent = clusteringModel.hasClusters();
    if(!clustersPresent){
        throw new MissingElementException(clusteringModel, PMMLElements.CLUSTERINGMODEL_CLUSTERS);
    }

    Targets misplacedTargets = clusteringModel.getTargets();
    if(misplacedTargets != null){
        throw new MisplacedElementException(misplacedTargets);
    }
}
/**
 * Computes the distance from the input values to every cluster of the model.
 *
 * <p>An adjustment value is calculated first, using the model's missing value weights when
 * they are declared. Both the weights and each cluster's cached values must have the same
 * number of entries as {@code values}; otherwise an {@link InvalidElementException} is thrown
 * for the offending element.
 *
 * @param valueFactory factory for the numeric values being computed
 * @param comparisonMeasure the comparison measure of the model
 * @param clusteringFields the clustering fields that correspond to {@code values}
 * @param values the evaluated input field values
 * @return distribution of type {@code DISTANCE} holding one distance per cluster
 */
private <V extends Number> ClusterAffinityDistribution<V> evaluateDistance(ValueFactory<V> valueFactory, ComparisonMeasure comparisonMeasure, List<ClusteringField> clusteringFields, List<FieldValue> values){
    ClusteringModel model = getModel();

    List<Cluster> modelClusters = model.getClusters();

    MissingValueWeights weights = model.getMissingValueWeights();

    Value<V> adjustment;
    if(weights == null){
        adjustment = MeasureUtil.calculateAdjustment(valueFactory, values);
    } else {
        List<? extends Number> weightValues = ArrayUtil.asNumberList(weights.getArray());

        // There must be exactly one weight per input value.
        if(values.size() != weightValues.size()){
            throw new InvalidElementException(weights);
        }

        adjustment = MeasureUtil.calculateAdjustment(valueFactory, values, weightValues);
    }

    ClusterAffinityDistribution<V> distribution = createClusterAffinityDistribution(Classification.Type.DISTANCE, modelClusters);

    for(Cluster candidate : modelClusters){
        List<FieldValue> centerValues = CacheUtil.getValue(candidate, ClusteringModelEvaluator.clusterValueCache);

        if(values.size() != centerValues.size()){
            throw new InvalidElementException(candidate);
        }

        distribution.put(candidate, MeasureUtil.evaluateDistance(valueFactory, comparisonMeasure, clusteringFields, values, centerValues, adjustment));
    }

    return distribution;
}
/**
 * Appends the given clustering fields to this model's clustering field list.
 *
 * @param clusteringFields the fields to append, in order
 * @return this model, to allow call chaining
 */
public ClusteringModel addClusteringFields(ClusteringField... clusteringFields) {
    List<ClusteringField> existing = getClusteringFields();
    existing.addAll(Arrays.asList(clusteringFields));

    return this;
}
// The model under test is expected to declare exactly 100 clusters.
assertEquals(100, clusteringModel.getNumberOfClusters());
@Override protected <V extends Number> Map<FieldName, ClusterAffinityDistribution<V>> evaluateClustering(ValueFactory<V> valueFactory, EvaluationContext context){ ClusteringModel clusteringModel = getModel(); ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure(); List<ClusteringField> clusteringFields = getCenterClusteringFields(); List<FieldValue> values = new ArrayList<>(clusteringFields.size()); for(int i = 0, max = clusteringFields.size(); i < max; i++){ ClusteringField clusteringField = clusteringFields.get(i); FieldName name = clusteringField.getField(); if(name == null){ throw new MissingAttributeException(clusteringField, PMMLAttributes.CLUSTERINGFIELD_FIELD); } FieldValue value = context.evaluate(name); values.add(value); } ClusterAffinityDistribution<V> result; Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure); if(measure instanceof Similarity){ result = evaluateSimilarity(valueFactory, comparisonMeasure, clusteringFields, values); } else if(measure instanceof Distance){ result = evaluateDistance(valueFactory, comparisonMeasure, clusteringFields, values); } else { throw new UnsupportedElementException(measure); } // "For clustering models, the identifier of the winning cluster is returned as the predictedValue" result.computeResult(DataType.STRING); return Collections.singletonMap(getTargetName(), result); }
/**
 * Offers this model, and then its child elements, to the given visitor.
 *
 * <p>Children are traversed in a fixed order (extensions, schema/output/statistics group,
 * clustering fields, center fields and missing value weights, clusters, model verification),
 * and each step only runs while the visitor keeps answering {@link VisitorAction#CONTINUE}.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the traversal was terminated,
 *  otherwise {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction outcome = visitor.visit(this);

    if (outcome == VisitorAction.CONTINUE) {
        visitor.pushParent(this);

        // The outcome is known to be CONTINUE here, so only the presence check is needed.
        if (hasExtensions()) {
            outcome = PMMLObject.traverse(visitor, getExtensions());
        }

        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getLocalTransformations(), getComparisonMeasure());
        }

        if (outcome == VisitorAction.CONTINUE && hasClusteringFields()) {
            outcome = PMMLObject.traverse(visitor, getClusteringFields());
        }

        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getCenterFields(), getMissingValueWeights());
        }

        if (outcome == VisitorAction.CONTINUE && hasClusters()) {
            outcome = PMMLObject.traverse(visitor, getClusters());
        }

        if (outcome == VisitorAction.CONTINUE) {
            outcome = PMMLObject.traverse(visitor, getModelVerification());
        }

        visitor.popParent();
    }

    // Any outcome other than TERMINATE is reported to the caller as CONTINUE.
    return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
// Exposes the clustering model's cluster list as the collection of this object.
// NOTE(review): the trailing "});" closes an enclosing construct that starts outside this excerpt.
@Override public Collection<?> getCollection(){ return clusteringModel.getClusters(); } });
.setMeasure(new SquaredEuclidean()); ClusteringModel clusteringModel = new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, rows, ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters) .setOutput(ClusteringModelUtil.createOutput(FieldName.create("cluster"), DataType.DOUBLE, clusters));
/**
 * Appends the given clustering fields to this model's clustering field list.
 *
 * @param clusteringFields the fields to append, in order
 * @return this model, to allow call chaining
 */
public ClusteringModel addClusteringFields(ClusteringField... clusteringFields) {
    List<ClusteringField> existing = getClusteringFields();
    existing.addAll(Arrays.asList(clusteringFields));

    return this;
}
/**
 * Reports the number of clusters declared by the clustering model.
 *
 * @return the model's declared number of clusters
 */
@Override
public Integer getSize(){
    Integer declaredCount = clusteringModel.getNumberOfClusters();

    return declaredCount;
}
clusters.add(new Cluster().setId("2").setSize(3).setArray(AppPMMLUtils.toArray(-1.0, 0.0))); pmml.addModels(new ClusteringModel( MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED,
/**
 * Appends the given clusters to this model's cluster list.
 *
 * @param clusters the clusters to append, in order
 * @return this model, to allow call chaining
 */
public ClusteringModel addClusters(Cluster... clusters) {
    List<Cluster> existing = getClusters();
    existing.addAll(Arrays.asList(clusters));

    return this;
}
/**
 * Selects the clustering fields of the model that are flagged as center fields.
 *
 * <p>Fields whose center field attribute is {@code TRUE} are kept in their original order,
 * fields flagged {@code FALSE} are dropped, and any other attribute value is rejected
 * with an {@link UnsupportedAttributeException}.
 *
 * @return the center clustering fields of the model
 */
private List<ClusteringField> getCenterClusteringFields(){
    List<ClusteringField> allFields = getModel().getClusteringFields();

    List<ClusteringField> selected = new ArrayList<>(allFields.size());

    for(ClusteringField candidate : allFields){
        ClusteringField.CenterField flag = candidate.getCenterField();

        switch(flag){
            case FALSE:
                break;
            case TRUE:
                selected.add(candidate);
                break;
            default:
                throw new UnsupportedAttributeException(candidate, flag);
        }
    }

    return selected;
}
/**
 * Creates a new {@link ClusteringModel} using its no-argument constructor.
 *
 * @return a freshly constructed {@link ClusteringModel}
 */
public ClusteringModel createClusteringModel() {
    ClusteringModel instance = new ClusteringModel();

    return instance;
}