/**
 * Creates a new {@link NearestNeighborModel} using its no-argument constructor.
 *
 * @return a freshly constructed, unpopulated {@link NearestNeighborModel}
 */
public NearestNeighborModel createNearestNeighborModel() {
	NearestNeighborModel model = new NearestNeighborModel();
	return model;
}
/**
 * Dispatches {@code visitor} over this model and its child elements.
 *
 * <p>The visitor first visits this element. Only if that returns {@code CONTINUE} is this
 * element pushed as the visitor's current parent. In that case the extensions (when present)
 * are traversed first. If that still leaves the status at {@code CONTINUE}, the remaining child
 * elements are traversed in the fixed order given below. The parent is then popped again.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the last step that ran reported
 *         {@code TERMINATE}; otherwise {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
	VisitorAction status = visitor.visit(this);

	if (status == VisitorAction.CONTINUE) {
		visitor.pushParent(this);

		// status is necessarily CONTINUE at this point, so only the presence of extensions matters
		if (hasExtensions()) {
			status = PMMLObject.traverse(visitor, getExtensions());
		}

		if (status == VisitorAction.CONTINUE) {
			status = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(), getLocalTransformations(), getTrainingInstances(), getComparisonMeasure(), getKNNInputs(), getModelVerification());
		}

		visitor.popParent();
	}

	return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Creates an evaluator for the given nearest-neighbor model.
 *
 * <p>The model is checked for the elements this evaluator depends on. It must have a
 * comparison measure and training instances with a non-empty instance-field list, and its KNN
 * inputs must be present and non-empty.
 *
 * @param pmml the enclosing PMML document
 * @param nearestNeighborModel the model to evaluate
 * @throws MissingElementException if any of the required elements listed above is absent or empty
 */
public NearestNeighborModelEvaluator(PMML pmml, NearestNeighborModel nearestNeighborModel){
	super(pmml, nearestNeighborModel);

	ComparisonMeasure measure = nearestNeighborModel.getComparisonMeasure();
	if(measure == null){
		throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_COMPARISONMEASURE);
	}

	TrainingInstances instances = nearestNeighborModel.getTrainingInstances();
	if(instances == null){
		throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_TRAININGINSTANCES);
	}

	// The container element must exist AND list at least one field
	InstanceFields fields = instances.getInstanceFields();
	if(fields == null){
		throw new MissingElementException(instances, PMMLElements.TRAININGINSTANCES_INSTANCEFIELDS);
	}
	if(!fields.hasInstanceFields()){
		throw new MissingElementException(fields, PMMLElements.INSTANCEFIELDS_INSTANCEFIELDS);
	}

	// Same two-level requirement for the KNN inputs
	KNNInputs inputs = nearestNeighborModel.getKNNInputs();
	if(inputs == null){
		throw new MissingElementException(nearestNeighborModel, PMMLElements.NEARESTNEIGHBORMODEL_KNNINPUTS);
	}
	if(!inputs.hasKNNInputs()){
		throw new MissingElementException(inputs, PMMLElements.KNNINPUTS_KNNINPUTS);
	}
}
NearestNeighborModel nearestNeighborModel = modelEvaluator.getModel(); FieldName instanceIdVariable = nearestNeighborModel.getInstanceIdVariable(); TrainingInstances trainingInstances = nearestNeighborModel.getTrainingInstances(); KNNInputs knnInputs = nearestNeighborModel.getKNNInputs(); for(KNNInput knnInput : knnInputs){ FieldName name = knnInput.getField(); int numberOfNeighbors = nearestNeighborModel.getNumberOfNeighbors(); if(numberOfNeighbors < 0 || result.size() < numberOfNeighbors){ throw new InvalidAttributeException(nearestNeighborModel, PMMLAttributes.NEARESTNEIGHBORMODEL_NUMBEROFNEIGHBORS, numberOfNeighbors);
/**
 * Evaluates every KNN input field of the model in {@code context}, then scores those input
 * values via {@code evaluateSimilarity} or {@code evaluateDistance}, depending on the kind of
 * comparison measure.
 *
 * @param valueFactory factory for the numeric result values
 * @param context the evaluation context supplying the input field values
 * @return the per-instance results produced by the selected scoring routine
 * @throws MissingAttributeException if a {@code KNNInput} has no {@code field} attribute
 * @throws UnsupportedElementException if the measure is neither a similarity nor a distance
 */
private <V extends Number> List<InstanceResult<V>> evaluateInstanceRows(ValueFactory<V> valueFactory, EvaluationContext context){
	NearestNeighborModel model = getModel();

	ComparisonMeasure comparisonMeasure = model.getComparisonMeasure();

	KNNInputs inputs = model.getKNNInputs();

	// One value per KNN input, in declaration order
	List<FieldValue> inputValues = new ArrayList<>();
	for(KNNInput input : inputs){
		FieldName fieldName = input.getField();

		if(fieldName == null){
			throw new MissingAttributeException(input, PMMLAttributes.KNNINPUT_FIELD);
		}

		inputValues.add(context.evaluate(fieldName));
	}

	Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);

	if(measure instanceof Similarity){
		return evaluateSimilarity(valueFactory, comparisonMeasure, inputs.getKNNInputs(), inputValues);
	}

	if(measure instanceof Distance){
		return evaluateDistance(valueFactory, comparisonMeasure, inputs.getKNNInputs(), inputValues);
	}

	throw new UnsupportedElementException(measure);
}
NearestNeighborModel nearestNeighborModel = new NearestNeighborModel(MiningFunction.REGRESSION, numberOfNeighbors, ModelUtil.createMiningSchema(schema.getLabel()), trainingInstances, comparisonMeasure, knnInputs) .setOutput(output);
/**
 * Projects each training-instance row onto the model's KNN input fields.
 *
 * <p>The training table is obtained through {@code modelEvaluator.getValue(trainingInstanceCache, ...)}.
 * NOTE(review): presumably that call caches the loaded table; confirm. Rows are visited in
 * ascending row-key order, and the returned map preserves that order.
 *
 * @param modelEvaluator the evaluator whose model and training data are used
 * @return row key to the list of that row's values, one per KNN input in declaration order.
 *         An entry is {@code null} where the row map has no value for the input's field.
 */
static
private Map<Integer, List<FieldValue>> loadInstanceValues(NearestNeighborModelEvaluator modelEvaluator){
	NearestNeighborModel model = modelEvaluator.getModel();

	Map<Integer, List<FieldValue>> instanceValues = new LinkedHashMap<>();

	Table<Integer, FieldName, FieldValue> trainingTable = modelEvaluator.getValue(NearestNeighborModelEvaluator.trainingInstanceCache, createTrainingInstanceLoader(modelEvaluator));

	KNNInputs inputs = model.getKNNInputs();

	for(Integer id : ImmutableSortedSet.copyOf(trainingTable.rowKeySet())){
		List<FieldValue> rowInputValues = new ArrayList<>();

		Map<FieldName, FieldValue> row = trainingTable.row(id);

		for(KNNInput input : inputs){
			rowInputValues.add(row.get(input.getField()));
		}

		instanceValues.put(id, rowInputValues);
	}

	return instanceValues;
}
/**
 * Wraps per-instance results in an {@link AffinityDistribution}.
 *
 * <p>Each result's value is keyed by the identifier {@code function} maps its id to. The
 * distribution is typed {@code SIMILARITY} or {@code DISTANCE} according to the model's
 * comparison measure.
 *
 * @param instanceResults the per-instance scores
 * @param function maps an instance id to its identifier string
 * @param result the value passed through to the distribution as its result
 * @return the assembled distribution
 * @throws UnsupportedElementException if the measure is neither a similarity nor a distance
 */
private <V extends Number> AffinityDistribution<V> createAffinityDistribution(List<InstanceResult<V>> instanceResults, Function<Integer, String> function, Object result){
	NearestNeighborModel model = getModel();

	ComparisonMeasure comparisonMeasure = model.getComparisonMeasure();

	ValueMap<String, V> affinities = new ValueMap<>(2 * instanceResults.size());
	for(InstanceResult<V> instanceResult : instanceResults){
		String id = function.apply(instanceResult.getId());

		affinities.put(id, instanceResult.getValue());
	}

	Measure measure = MeasureUtil.ensureMeasure(comparisonMeasure);

	Classification.Type type;
	if(measure instanceof Similarity){
		type = Classification.Type.SIMILARITY;
	} else

	if(measure instanceof Distance){
		type = Classification.Type.DISTANCE;
	} else

	{
		throw new UnsupportedElementException(measure);
	}

	return new AffinityDistribution<>(type, affinities, result);
}
int numberOfNeighbors = nearestNeighborModel.getNumberOfNeighbors(); FieldName instanceIdVariable = nearestNeighborModel.getInstanceIdVariable(); if(instanceIdVariable != null){ function = createIdentifierResolver(instanceIdVariable, table);
NearestNeighborModel.CategoricalScoringMethod categoricalScoringMethod = nearestNeighborModel.getCategoricalScoringMethod(); InstanceResult.Distance distance = TypeUtil.cast(InstanceResult.Distance.class, instanceResult); Value<V> weight = distance.getWeight(nearestNeighborModel.getThreshold());
private <V extends Number> V calculateContinuousTarget(ValueFactory<V> valueFactory, FieldName name, List<InstanceResult<V>> instanceResults, Table<Integer, FieldName, FieldValue> table){ NearestNeighborModel nearestNeighborModel = getModel(); NearestNeighborModel.ContinuousScoringMethod continuousScoringMethod = nearestNeighborModel.getContinuousScoringMethod(); InstanceResult.Distance distance = TypeUtil.cast(InstanceResult.Distance.class, instanceResult); Value<V> weight = distance.getWeight(nearestNeighborModel.getThreshold());
/**
 * Appends {@code extensions}, in order, to this model's extension list.
 *
 * @param extensions the extensions to append
 * @return this model, to allow call chaining
 */
@Override
public NearestNeighborModel addExtensions(org.dmg.pmml.Extension... extensions) {
	java.util.Collections.addAll(getExtensions(), extensions);
	return this;
}
/**
 * Scores the current record clustering-style.
 *
 * <p>The per-instance results from {@code evaluateInstanceRows} are wrapped in an
 * {@link AffinityDistribution}, keyed by the value of each training instance's
 * {@code instanceIdVariable}.
 *
 * @param valueFactory factory for the numeric result values
 * @param context the evaluation context supplying the input field values
 * @return a single-entry, immutable map from the target name to the affinity distribution
 * @throws MissingAttributeException if the model has no {@code instanceIdVariable} attribute
 */
@Override
protected <V extends Number> Map<FieldName, AffinityDistribution<V>> evaluateClustering(ValueFactory<V> valueFactory, EvaluationContext context){
	NearestNeighborModel nearestNeighborModel = getModel();

	// Validate the mandatory attribute before doing any scoring work: without it the results
	// cannot be labelled, so scoring every training instance first would be wasted effort.
	FieldName instanceIdVariable = nearestNeighborModel.getInstanceIdVariable();
	if(instanceIdVariable == null){
		throw new MissingAttributeException(nearestNeighborModel, PMMLAttributes.NEARESTNEIGHBORMODEL_INSTANCEIDVARIABLE);
	}

	Table<Integer, FieldName, FieldValue> table = getTrainingInstances();

	List<InstanceResult<V>> instanceResults = evaluateInstanceRows(valueFactory, context);

	Function<Integer, String> function = createIdentifierResolver(instanceIdVariable, table);

	AffinityDistribution<V> result = createAffinityDistribution(instanceResults, function, null);

	return Collections.singletonMap(getTargetName(), result);
}
/**
 * Appends {@code extensions}, in order, to this model's extension list.
 *
 * @param extensions the extensions to append
 * @return this model, to allow call chaining
 */
@Override
public NearestNeighborModel addExtensions(org.dmg.pmml.Extension... extensions) {
	java.util.Collections.addAll(getExtensions(), extensions);
	return this;
}
/**
 * Dispatches {@code visitor} over this model and its child elements.
 *
 * <p>The visitor first visits this element. Only if that returns {@code CONTINUE} is this
 * element pushed as the visitor's current parent. In that case the extensions (when present)
 * are traversed first. If that still leaves the status at {@code CONTINUE}, the remaining child
 * elements are traversed in the fixed order given below. The parent is then popped again.
 *
 * @param visitor the visitor to apply
 * @return {@link VisitorAction#TERMINATE} if the last step that ran reported
 *         {@code TERMINATE}; otherwise {@link VisitorAction#CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
	VisitorAction status = visitor.visit(this);

	if (status == VisitorAction.CONTINUE) {
		visitor.pushParent(this);

		// status is necessarily CONTINUE at this point, so only the presence of extensions matters
		if (hasExtensions()) {
			status = PMMLObject.traverse(visitor, getExtensions());
		}

		if (status == VisitorAction.CONTINUE) {
			status = PMMLObject.traverse(visitor, getMiningSchema(), getOutput(), getModelStats(), getModelExplanation(), getTargets(), getLocalTransformations(), getTrainingInstances(), getComparisonMeasure(), getKNNInputs(), getModelVerification());
		}

		visitor.popParent();
	}

	return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Creates a new {@link NearestNeighborModel} using its no-argument constructor.
 *
 * @return a freshly constructed, unpopulated {@link NearestNeighborModel}
 */
public NearestNeighborModel createNearestNeighborModel() {
	NearestNeighborModel model = new NearestNeighborModel();
	return model;
}
@Test public void inspectTypeAnnotations(){ PMML pmml = createPMML(); assertVersionRange(pmml, Version.PMML_3_0, Version.PMML_4_3); pmml.addModels(new AssociationModel(), //new ClusteringModel(), //new GeneralRegressionModel(), //new MiningModel(), new NaiveBayesModel(), new NeuralNetwork(), new RegressionModel(), new RuleSetModel(), new SequenceModel(), //new SupportVectorMachineModel(), new TextModel(), new TreeModel()); assertVersionRange(pmml, Version.PMML_3_0, Version.PMML_4_3); pmml.addModels(new TimeSeriesModel()); assertVersionRange(pmml, Version.PMML_4_0, Version.PMML_4_3); pmml.addModels(new BaselineModel(), new Scorecard(), new NearestNeighborModel()); assertVersionRange(pmml, Version.PMML_4_1, Version.PMML_4_3); pmml.addModels(new BayesianNetworkModel(), new GaussianProcessModel()); assertVersionRange(pmml, Version.PMML_4_3, Version.PMML_4_3); }