private Predicate buildPredicate(Split split, CategoricalValueEncodings categoricalValueEncodings) { if (split == null) { // Left child always applies, but is evaluated second return new True(); } int featureIndex = inputSchema.predictorToFeatureIndex(split.feature()); FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(featureIndex)); if (split.featureType().equals(FeatureType.Categorical())) { // Note that categories in MLlib model select the *left* child but the // convention here will be that the predicate selects the *right* child // So the predicate will evaluate "not in" this set // More ugly casting @SuppressWarnings("unchecked") Collection<Double> javaCategories = (Collection<Double>) (Collection<?>) JavaConversions.seqAsJavaList(split.categories()); Set<Integer> negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet()); Map<Integer,String> encodingToValue = categoricalValueEncodings.getEncodingValueMap(featureIndex); List<String> negativeValues = negativeEncodings.stream().map(encodingToValue::get).collect(Collectors.toList()); String joinedValues = TextUtils.joinPMMLDelimited(negativeValues); return new SimpleSetPredicate(fieldName, SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, joinedValues)); } else { // For MLlib, left means <= threshold, so right means > return new SimplePredicate(fieldName, SimplePredicate.Operator.GREATER_THAN) .setValue(Double.toString(split.threshold())); } }
// Append one segment for this tree: its ID is the tree ID rendered as a string,
// its predicate is always true, and it wraps the tree model with a weight of 1.0
segments.add(new Segment()
    .setId(Integer.toString(treeID))
    .setPredicate(new True())
    .setModel(treeModel)
    .setWeight(1.0)); // No weights in MLlib impl now
// Root node "r" and its left child "r-", both guarded by an always-true predicate
Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True());
Node left = new Node().setId("r-").setRecordCount(halfCount).setPredicate(new True());
// The left child carries a single score distribution for the value "apple"
left.addScoreDistributions(new ScoreDistribution("apple", halfCount));
// NOTE(review): the statement below has no terminator before the next statement
// begins; lines appear to have been elided from this excerpt -- confirm against
// the complete file before relying on it
Node right = new Node().setId("r+").setRecordCount(halfCount)
// Append one segment per tree, identified by the index i, with weight 1.0
segments.add(new Segment()
    .setId(Integer.toString(i))
    .setPredicate(new True())
    .setModel(treeModel)
    .setWeight(1.0));
/**
 * Factory method producing a fresh {@link True} instance.
 *
 * @return a newly constructed {@link True}
 */
public True createTrue() {
    True instance = new True();
    return instance;
}
/**
 * Builds and returns a new {@link True} object.
 *
 * @return the newly created {@link True}
 */
public True createTrue() {
    final True created = new True();
    return created;
}
/**
 * Encodes a compressed tree as a regression {@link TreeModel}.
 *
 * The root node is guarded by an always-true predicate, node encoding starts
 * with an ID sequence of 1, and the resulting model uses the
 * {@code DEFAULT_CHILD} missing value strategy.
 *
 * @param compressedTree the compressed tree bytes handed to {@code encodeNode}
 * @param predicateManager the predicate manager handed to {@code encodeNode}
 * @param schema the schema handed to {@code encodeNode}
 * @return the encoded tree model
 */
static public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema){
	AtomicInteger nodeIdCounter = new AtomicInteger(1);
	ByteBufferWrapper treeBuffer = new ByteBufferWrapper(compressedTree);

	True rootPredicate = new True();
	CategoryManager categoryManager = new CategoryManager();

	Node rootNode = encodeNode(rootPredicate, nodeIdCounter, compressedTree, treeBuffer, predicateManager, categoryManager, schema);

	// The mining schema is built from a continuous double label without a name
	Label target = new ContinuousLabel(null, DataType.DOUBLE);

	return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(target), rootNode)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
}
/**
 * Encodes an R tree structure as a binary-split {@link TreeModel}.
 *
 * @param tree the R tree object handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model, using this converter's mining function
 */
private TreeModel encodeTreeModel(RGenericVector tree, Schema schema){
	True rootPredicate = new True();

	// Nodes are encoded first; the mining function field is read afterwards
	Node rootNode = encodeNode(rootPredicate, tree, schema);

	return new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
/**
 * Appends an always-true predicate to the compound predicate of a node.
 *
 * If the node's predicate is not already a {@link CompoundPredicate}, it is
 * first wrapped into a new surrogate compound predicate, which then replaces
 * the node's predicate.
 *
 * @param node the node whose predicate is extended
 */
private void makeDefault(Node node){
	Predicate existing = node.getPredicate();

	if(!(existing instanceof CompoundPredicate)){
		// Wrap the plain predicate so that a fallback can be appended to it
		CompoundPredicate surrogate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE)
			.addPredicates(existing);

		node.setPredicate(surrogate);

		surrogate.addPredicates(new True());

		return;
	}

	((CompoundPredicate)existing).addPredicates(new True());
}
/**
 * Converts a tree node into a PMML node with two converted children.
 *
 * The PMML node takes the source node's ID as a string, has an always-true
 * predicate, and has no default child and no embedded model. The left and
 * right children are converted with the source node's split and appended in
 * that order.
 *
 * @param node the source node to convert
 * @return the resulting PMML node
 */
public org.dmg.pmml.Node convert(Node node) {
    org.dmg.pmml.Node result = new org.dmg.pmml.Node();
    result.setId(String.valueOf(node.getId()));
    result.setDefaultChild(null);
    result.setPredicate(new True());
    result.setEmbeddedModel(null);

    List<org.dmg.pmml.Node> children = result.getNodes();
    // The boolean flag is true for the left child and false for the right child
    children.add(convert(node.getLeft(), true, node.getSplit()));
    children.add(convert(node.getRight(), false, node.getSplit()));

    return result;
}
/**
 * Encodes this tree as a regression {@link TreeModel}.
 *
 * Encoding starts at index 0 under an always-true root predicate. The model is
 * configured with binary splits and the {@code DEFAULT_CHILD} missing value
 * strategy.
 *
 * @param predicateManager the predicate manager handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){
	True rootPredicate = new True();
	CategoryManager categoryManager = new CategoryManager();

	Node rootNode = encodeNode(rootPredicate, predicateManager, categoryManager, 0, schema);

	return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
}
/**
 * Encodes this tree as a regression {@link TreeModel}.
 *
 * Encoding starts at index 0 under an always-true root predicate. The model is
 * configured with binary splits, the {@code DEFAULT_CHILD} missing value
 * strategy and the {@code FLOAT} math context.
 *
 * @param predicateManager the predicate manager handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){
	True rootPredicate = new True();

	org.dmg.pmml.tree.Node rootNode = encodeNode(rootPredicate, predicateManager, 0, schema);

	return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD)
		.setMathContext(MathContext.FLOAT);
}
/**
 * Encodes an R tree and its categorical splits as a multi-split {@link TreeModel}.
 *
 * Encoding starts at index 0 under an always-true root predicate, with a fresh
 * flag manager and a fresh category manager.
 *
 * @param miningFunction the mining function of the resulting model
 * @param tree the R tree object handed to {@code encodeNode}
 * @param c_splits the R splits object handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema){
	True rootPredicate = new True();
	FlagManager flagManager = new FlagManager();
	CategoryManager categoryManager = new CategoryManager();

	Node rootNode = encodeNode(rootPredicate, 0, tree, c_splits, flagManager, categoryManager, schema);

	return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
}
/**
 * Produces the regression {@link TreeModel} representation of this tree.
 *
 * The root node is encoded from index 0 with an always-true predicate. The
 * returned model declares binary splits, the {@code DEFAULT_CHILD} missing
 * value strategy and the {@code FLOAT} math context.
 *
 * @param predicateManager the predicate manager handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){
	final True alwaysTrue = new True();
	final org.dmg.pmml.tree.Node encodedRoot = encodeNode(alwaysTrue, predicateManager, 0, schema);

	return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), encodedRoot)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD)
		.setMathContext(MathContext.FLOAT);
}
/**
 * Encodes a regression tree as a {@link TreeModel}.
 *
 * Each node receives its score from the {@code yval} column of the frame and
 * its record count from {@code n}, both read at the node's offset. Encoding
 * starts at node 1 under an always-true root predicate, and the model is
 * finished by {@code configureTreeModel}.
 *
 * @param frame the R frame providing the {@code yval} column
 * @param rowNames handed to {@code encodeNode}
 * @param var handed to {@code encodeNode}
 * @param n per-node record counts
 * @param splitInfo handed to {@code encodeNode}
 * @param splits handed to {@code encodeNode}
 * @param csplit handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the configured regression tree model
 */
private TreeModel encodeRegression(RGenericVector frame, RIntegerVector rowNames, RIntegerVector var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema){
	RNumberVector<?> nodeValues = (RNumberVector<?>)frame.getValue("yval");

	ScoreEncoder regressionEncoder = new ScoreEncoder(){

		@Override
		public Node encode(Node node, int offset){
			// Read both values before touching the node
			Number prediction = nodeValues.getValue(offset);
			Number count = n.getValue(offset);

			node
				.setScore(prediction)
				.setRecordCount(count.doubleValue());

			return node;
		}
	};

	True rootPredicate = new True();

	Node rootNode = encodeNode(rootPredicate, 1, rowNames, var, n, splitInfo, splits, csplit, regressionEncoder, schema);

	return configureTreeModel(new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), rootNode));
}
/**
 * Encodes a tree given as parallel vectors into a binary-split {@link TreeModel}.
 *
 * The left and right child ID vectors are the first and second elements of
 * {@code childNodeIDs}. Encoding starts at index 0 under an always-true root
 * predicate.
 *
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder the score encoder handed to {@code encodeNode}
 * @param childNodeIDs holds the left child IDs at index 0 and the right child IDs at index 1
 * @param splitVarIDs handed to {@code encodeNode}
 * @param splitValues handed to {@code encodeNode}
 * @param terminalClassCounts handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema){
	RNumberVector<?> leftIds = (RNumberVector<?>)childNodeIDs.getValue(0);
	RNumberVector<?> rightIds = (RNumberVector<?>)childNodeIDs.getValue(1);

	True rootPredicate = new True();
	CategoryManager categoryManager = new CategoryManager();

	Node rootNode = encodeNode(rootPredicate, 0, scoreEncoder, leftIds, rightIds, splitVarIDs, splitValues, terminalClassCounts, categoryManager, schema);

	return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
/**
 * Encodes a decision tree model as a binary-split {@link TreeModel}.
 *
 * A root node with an always-true predicate is created first and then handed,
 * together with the model's root node, to {@code encodeNode}.
 *
 * @param model the decision tree model whose root node is encoded
 * @param predicateManager the predicate manager handed to {@code encodeNode}
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder the score encoder handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
static private <M extends Model<M> & DecisionTreeModel> TreeModel encodeTreeModel(M model, PredicateManager predicateManager, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema){
	Node pmmlRoot = new Node()
		.setPredicate(new True());

	CategoryManager categoryManager = new CategoryManager();

	encodeNode(pmmlRoot, model.rootNode(), predicateManager, categoryManager, scoreEncoder, schema);

	return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), pmmlRoot)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
/**
 * Encodes this tree as a regression {@link TreeModel}.
 *
 * A root node with an always-true predicate is created and handed to
 * {@code encodeNode} starting at index 0. The model is configured with binary
 * splits, the {@code NONE} missing value strategy and the {@code FLOAT} math
 * context.
 *
 * @param schema the schema handed to {@code encodeNode} and supplying the label
 * @return the encoded tree model
 */
public TreeModel encodeTreeModel(Schema schema){
	org.dmg.pmml.tree.Node pmmlRoot = new org.dmg.pmml.tree.Node()
		.setPredicate(new True());

	encodeNode(pmmlRoot, 0, schema);

	return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), pmmlRoot)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.NONE)
		.setMathContext(MathContext.FLOAT);
}
/**
 * Encodes one tree of the forest as a binary-split {@link TreeModel}.
 *
 * Encoding starts at index 0 under an always-true root predicate. The model
 * uses the {@code NULL_PREDICTION} missing value strategy. When this
 * converter's {@code compact} flag is set, a {@code RandomForestCompactor}
 * visitor is applied to the model before it is returned.
 *
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder the score encoder handed to {@code encodeNode}
 * @param leftDaughter handed to {@code encodeNode}
 * @param rightDaughter handed to {@code encodeNode}
 * @param nodepred handed to {@code encodeNode}
 * @param bestvar handed to {@code encodeNode}
 * @param xbestsplit handed to {@code encodeNode}
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded, optionally compacted, tree model
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema){
	// Note the argument order: encodeNode takes bestvar and xbestsplit before nodepred
	Node root = encodeNode(new True(), 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, new CategoryManager(), schema);

	TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
		.setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

	if(this.compact){
		Visitor visitor = new RandomForestCompactor();

		visitor.applyTo(treeModel);
	}

	return treeModel;
}
/**
 * Encodes the tree of an estimator as a binary-split {@link TreeModel}.
 *
 * The child, feature, threshold and value arrays are read from the
 * estimator's tree, and encoding starts at index 0 under an always-true root
 * predicate. After the model has been built, the tree is passed to
 * {@code ClassDictUtil.clearContent}.
 *
 * @param estimator the estimator providing the tree
 * @param predicateManager the predicate manager handed to {@code encodeNode}
 * @param scoreDistributionManager the score distribution manager handed to {@code encodeNode}
 * @param miningFunction the mining function of the resulting model
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
static public <E extends Estimator & HasTree> TreeModel encodeTreeModel(E estimator, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, MiningFunction miningFunction, Schema schema){
	Tree sourceTree = estimator.getTree();

	int[] leftIndices = sourceTree.getChildrenLeft();
	int[] rightIndices = sourceTree.getChildrenRight();
	int[] featureIndices = sourceTree.getFeature();
	double[] splitThresholds = sourceTree.getThreshold();
	double[] nodeValues = sourceTree.getValues();

	True rootPredicate = new True();

	Node rootNode = encodeNode(rootPredicate, predicateManager, scoreDistributionManager, 0, leftIndices, rightIndices, featureIndices, splitThresholds, nodeValues, miningFunction, schema);

	TreeModel result = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

	// Clear the tree's content only once the model has been fully built
	ClassDictUtil.clearContent(sourceTree);

	return result;
}