/**
 * Descends one level in a decision tree: given the current internal node and its split,
 * returns the child node that the given feature vector is routed to.
 *
 * @param featureVector encoded feature values for one example
 * @param node current (non-leaf) tree node
 * @param split the split at {@code node}
 * @param featureIndex index of the feature the split tests
 * @return the left or right child of {@code node}, per the split decision
 */
private static org.apache.spark.mllib.tree.model.Node nextNode(
    double[] featureVector,
    org.apache.spark.mllib.tree.model.Node node,
    Split split,
    int featureIndex) {
  double value = featureVector[featureIndex];
  boolean goLeft;
  if (split.featureType().equals(FeatureType.Continuous())) {
    // Continuous feature: go left iff the value is at or below the split threshold
    goLeft = value <= split.threshold();
  } else {
    // Categorical feature: go left iff the encoded category is in the split's category set
    goLeft = split.categories().contains(value);
  }
  return goLeft ? node.leftNode().get() : node.rightNode().get();
}
// NOTE(review): this is a mid-method fragment — the enclosing method header and the
// declarations of nodeIDCounts, modelNode, positiveModelNode, negativeModelNode and
// classificationTask are outside this view; braces opened here close elsewhere.
long nodeCount = nodeIDCounts.get(treeNode.id());
// Record how many training examples reached this tree node (setRecordCount takes a double).
modelNode.setRecordCount((double) nodeCount);
if (treeNode.isLeaf()) {
  Predict prediction = treeNode.predict();
  // predict() returns the target category as a double; cast back to its integer encoding.
  int targetEncodedValue = (int) prediction.predict();
  if (classificationTask) {
    // NOTE(review): calling split().get() on a node that isLeaf() looks suspicious —
    // a leaf normally has no split; confirm this fragment was not spliced together
    // from the leaf and non-leaf branches of the original method.
    Split split = treeNode.split().get();
    modelNode.addNodes(positiveModelNode, negativeModelNode);
    org.apache.spark.mllib.tree.model.Node rightTreeNode = treeNode.rightNode().get();
    org.apache.spark.mllib.tree.model.Node leftTreeNode = treeNode.leftNode().get();
    // The default child — used when the split feature is missing at prediction time —
    // is whichever child saw more training examples.
    boolean defaultRight = nodeIDCounts.get(rightTreeNode.id()) > nodeIDCounts.get(leftTreeNode.id());
    modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId());
/** * @param trainPointData data to run down trees * @param model random decision forest model to count on * @return map of predictor index to the number of training examples that reached a * node whose decision is based on that feature. The index is among predictors, not all * features, since there are fewer predictors than features. That is, the index will * match the one used in the {@link RandomForestModel}. */ private static IntLongHashMap predictorExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData, RandomForestModel model) { return trainPointData.mapPartitions(data -> { IntLongHashMap featureIndexCount = new IntLongHashMap(); data.forEachRemaining(datum -> { double[] featureVector = datum.features().toArray(); for (DecisionTreeModel tree : model.trees()) { org.apache.spark.mllib.tree.model.Node node = tree.topNode(); // This logic cloned from Node.predict: while (!node.isLeaf()) { Split split = node.split().get(); int featureIndex = split.feature(); // Count feature featureIndexCount.addToValue(featureIndex, 1); node = nextNode(featureVector, node, split, featureIndex); } } }); return Collections.singleton(featureIndexCount).iterator(); }).reduce(RDFUpdate::merge); }
// Just a stub of how the search for a specific node might work (this is not the real
// implementation):
//   Node currentNode = ...
//   if (comparator.compare(currentNode.content, toSearch) < 0)
//       currentNode = currentNode.leftNode();
//   else ...