int featureIndex = split.feature(); node = nextNode(featureVector, node, split, featureIndex);
private Predicate buildPredicate(Split split, CategoricalValueEncodings categoricalValueEncodings) { if (split == null) { // Left child always applies, but is evaluated second return new True(); } int featureIndex = inputSchema.predictorToFeatureIndex(split.feature()); FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(featureIndex)); if (split.featureType().equals(FeatureType.Categorical())) { // Note that categories in MLlib model select the *left* child but the // convention here will be that the predicate selects the *right* child // So the predicate will evaluate "not in" this set // More ugly casting @SuppressWarnings("unchecked") Collection<Double> javaCategories = (Collection<Double>) (Collection<?>) JavaConversions.seqAsJavaList(split.categories()); Set<Integer> negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet()); Map<Integer,String> encodingToValue = categoricalValueEncodings.getEncodingValueMap(featureIndex); List<String> negativeValues = negativeEncodings.stream().map(encodingToValue::get).collect(Collectors.toList()); String joinedValues = TextUtils.joinPMMLDelimited(negativeValues); return new SimpleSetPredicate(fieldName, SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, joinedValues)); } else { // For MLlib, left means <= threshold, so right means > return new SimplePredicate(fieldName, SimplePredicate.Operator.GREATER_THAN) .setValue(Double.toString(split.threshold())); } }
/** * @param trainPointData data to run down trees * @param model random decision forest model to count on * @return map of predictor index to the number of training examples that reached a * node whose decision is based on that feature. The index is among predictors, not all * features, since there are fewer predictors than features. That is, the index will * match the one used in the {@link RandomForestModel}. */ private static IntLongHashMap predictorExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData, RandomForestModel model) { return trainPointData.mapPartitions(data -> { IntLongHashMap featureIndexCount = new IntLongHashMap(); data.forEachRemaining(datum -> { double[] featureVector = datum.features().toArray(); for (DecisionTreeModel tree : model.trees()) { org.apache.spark.mllib.tree.model.Node node = tree.topNode(); // This logic cloned from Node.predict: while (!node.isLeaf()) { Split split = node.split().get(); int featureIndex = split.feature(); // Count feature featureIndexCount.addToValue(featureIndex, 1); node = nextNode(featureVector, node, split, featureIndex); } } }); return Collections.singleton(featureIndexCount).iterator(); }).reduce(RDFUpdate::merge); }