@Override
public ClusteringModel encodeModel(Schema schema){
int[] shape = getClusterCentersShape();
int numberOfClusters = shape[0];
int numberOfFeatures = shape[1];
List<? extends Number> clusterCenters = getClusterCenters();
List<Integer> labels = getLabels();
Multiset<Integer> labelCounts = HashMultiset.create();
if(labels != null){
labelCounts.addAll(labels);
}
List<Cluster> clusters = new ArrayList<>();
for(int i = 0; i < numberOfClusters; i++){
Cluster cluster = new Cluster()
.setId(String.valueOf(i))
.setSize((labelCounts.size () > 0 ? labelCounts.count(i) : null))
.setArray(PMMLUtil.createRealArray(CMatrixUtil.getRow(clusterCenters, numberOfClusters, numberOfFeatures, i)));
clusters.add(cluster);
}
ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
.setCompareFunction(CompareFunction.ABS_DIFF)
.setMeasure(new SquaredEuclidean());
ClusteringModel clusteringModel = new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, numberOfClusters, ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters)
.setOutput(ClusteringModelUtil.createOutput(FieldName.create("Cluster"), DataType.DOUBLE, clusters));
return clusteringModel;
}