/**
 * Builds an {@link AffinityDistribution} from per-instance results.
 *
 * <p>Each result is re-keyed by the string that {@code function} produces from the
 * instance id. The distribution type (SIMILARITY or DISTANCE) follows the kind of
 * measure declared by the model's ComparisonMeasure.
 *
 * @param instanceResults the per-instance values to include
 * @param function maps an instance id to the key used in the distribution
 * @param result forwarded unchanged to the {@link AffinityDistribution} constructor
 * @return the affinity distribution
 * @throws UnsupportedElementException if the measure is neither a Similarity nor a Distance
 */
private <V extends Number> AffinityDistribution<V> createAffinityDistribution(List<InstanceResult<V>> instanceResults, Function<Integer, String> function, Object result){
	ComparisonMeasure comparisonMeasure = getModel().getComparisonMeasure();

	// Size argument of 2 * n kept from the original (presumably a capacity hint for the backing map)
	ValueMap<String, V> affinities = new ValueMap<>(2 * instanceResults.size());
	for(InstanceResult<V> entry : instanceResults){
		String key = function.apply(entry.getId());

		affinities.put(key, entry.getValue());
	}

	Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);

	Classification.Type type;
	if(measure instanceof Similarity){
		type = Classification.Type.SIMILARITY;
	} else if(measure instanceof Distance){
		type = Classification.Type.DISTANCE;
	} else {
		throw new UnsupportedElementException(measure);
	}

	return new AffinityDistribution<>(type, affinities, result);
}
/**
 * Evaluates every KNNInput field against the given context and scores the
 * resulting input vector with the model's comparison measure.
 *
 * @param valueFactory factory for the numeric result values
 * @param context the evaluation context used to resolve each KNNInput field
 * @return the per-instance results produced by the similarity or distance evaluation
 * @throws MissingAttributeException if a KNNInput has no {@code field} attribute
 * @throws UnsupportedElementException if the measure is neither a Similarity nor a Distance
 */
private <V extends Number> List<InstanceResult<V>> evaluateInstanceRows(ValueFactory<V> valueFactory, EvaluationContext context){
	NearestNeighborModel model = getModel();

	ComparisonMeasure comparisonMeasure = model.getComparisonMeasure();
	KNNInputs knnInputs = model.getKNNInputs();

	// One value per KNNInput, in declaration order
	List<FieldValue> inputValues = new ArrayList<>();
	for(KNNInput knnInput : knnInputs){
		FieldName fieldName = knnInput.getField();
		if(fieldName == null){
			throw new MissingAttributeException(knnInput, PMMLAttributes.KNNINPUT_FIELD);
		}

		inputValues.add(context.evaluate(fieldName));
	}

	Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);

	if(measure instanceof Similarity){
		return evaluateSimilarity(valueFactory, comparisonMeasure, knnInputs.getKNNInputs(), inputValues);
	}

	if(measure instanceof Distance){
		return evaluateDistance(valueFactory, comparisonMeasure, knnInputs.getKNNInputs(), inputValues);
	}

	throw new UnsupportedElementException(measure);
}
/**
 * Creates an evaluator for the given nearest neighbor model. Models that lack the
 * sub-elements listed under {@code @throws} are rejected.
 *
 * @param pmml the enclosing PMML document
 * @param nearestNeighborModel the model to evaluate
 * @throws MissingElementException if the model has no ComparisonMeasure,
 *         TrainingInstances or KNNInputs element; if TrainingInstances has no
 *         InstanceFields element; or if InstanceFields or KNNInputs contain no entries
 */
public NearestNeighborModelEvaluator(PMML pmml, NearestNeighborModel nearestNeighborModel){
	super(pmml, nearestNeighborModel);

	ComparisonMeasure comparisonMeasure = nearestNeighborModel.getComparisonMeasure();
	if(comparisonMeasure == null){
		throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_COMPARISONMEASURE);
	}

	TrainingInstances trainingInstances = nearestNeighborModel.getTrainingInstances();
	if(trainingInstances == null){
		throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_TRAININGINSTANCES);
	}

	InstanceFields instanceFields = trainingInstances.getInstanceFields();
	if(instanceFields == null){
		throw new MissingElementException(trainingInstances, PMMLElements.TRAININGINSTANCES_INSTANCEFIELDS);
	}

	// The element being present is not enough; it must also declare at least one InstanceField
	if(!instanceFields.hasInstanceFields()){
		throw new MissingElementException(instanceFields, PMMLElements.INSTANCEFIELDS_INSTANCEFIELDS);
	}

	KNNInputs knnInputs = nearestNeighborModel.getKNNInputs();
	if(knnInputs == null){
		throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_KNNINPUTS);
	}

	// Likewise, KNNInputs must contain at least one KNNInput
	if(!knnInputs.hasKNNInputs()){
		throw new MissingElementException(knnInputs, PMMLElements.KNNINPUTS_KNNINPUTS);
	}
}
/**
 * Dispatches this element to {@code visitor} and, if the visitor returns CONTINUE,
 * traverses the child elements with this element pushed as the current parent.
 *
 * <p>Extensions are traversed first (only when present). The remaining children
 * are traversed only if that step returned CONTINUE, in this order: MiningSchema,
 * Output, ModelStats, ModelExplanation, Targets, LocalTransformations,
 * TrainingInstances, ComparisonMeasure, KNNInputs, ModelVerification.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the final status is TERMINATE;
 *         otherwise {@link VisitorAction#CONTINUE} (any other status is reported as CONTINUE)
 */
@Override
public VisitorAction accept(Visitor visitor) {
	VisitorAction outcome = visitor.visit(this);

	if (outcome != VisitorAction.CONTINUE) {
		// The visitor did not ask to descend into this element
		return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
	}

	visitor.pushParent(this);

	// outcome is CONTINUE at this point, so only the presence of extensions matters
	if (hasExtensions()) {
		outcome = PMMLObject.traverse(visitor, getExtensions());
	}

	if (outcome == VisitorAction.CONTINUE) {
		outcome = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(), getLocalTransformations(), getTrainingInstances(), getComparisonMeasure(), getKNNInputs(), getModelVerification());
	}

	visitor.popParent();

	return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Dispatches this element to {@code visitor} and, if the visitor returns CONTINUE,
 * traverses the child elements with this element pushed as the current parent.
 *
 * <p>Extensions are traversed first (only when present). The remaining children
 * are traversed only if that step returned CONTINUE, in this order: MiningSchema,
 * Output, ModelStats, ModelExplanation, Targets, LocalTransformations,
 * TrainingInstances, ComparisonMeasure, KNNInputs, ModelVerification.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the final status is TERMINATE;
 *         otherwise {@link VisitorAction#CONTINUE} (any other status is reported as CONTINUE)
 */
@Override
public VisitorAction accept(Visitor visitor) {
	VisitorAction outcome = visitor.visit(this);

	if (outcome != VisitorAction.CONTINUE) {
		// The visitor did not ask to descend into this element
		return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
	}

	visitor.pushParent(this);

	// outcome is CONTINUE at this point, so only the presence of extensions matters
	if (hasExtensions()) {
		outcome = PMMLObject.traverse(visitor, getExtensions());
	}

	if (outcome == VisitorAction.CONTINUE) {
		outcome = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(), getLocalTransformations(), getTrainingInstances(), getComparisonMeasure(), getKNNInputs(), getModelVerification());
	}

	visitor.popParent();

	return (outcome == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}