private Predicate buildPredicate(Split split, CategoricalValueEncodings categoricalValueEncodings) { if (split == null) { // Left child always applies, but is evaluated second return new True(); } int featureIndex = inputSchema.predictorToFeatureIndex(split.feature()); FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(featureIndex)); if (split.featureType().equals(FeatureType.Categorical())) { // Note that categories in MLlib model select the *left* child but the // convention here will be that the predicate selects the *right* child // So the predicate will evaluate "not in" this set // More ugly casting @SuppressWarnings("unchecked") Collection<Double> javaCategories = (Collection<Double>) (Collection<?>) JavaConversions.seqAsJavaList(split.categories()); Set<Integer> negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet()); Map<Integer,String> encodingToValue = categoricalValueEncodings.getEncodingValueMap(featureIndex); List<String> negativeValues = negativeEncodings.stream().map(encodingToValue::get).collect(Collectors.toList()); String joinedValues = TextUtils.joinPMMLDelimited(negativeValues); return new SimpleSetPredicate(fieldName, SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, joinedValues)); } else { // For MLlib, left means <= threshold, so right means > return new SimplePredicate(fieldName, SimplePredicate.Operator.GREATER_THAN) .setValue(Double.toString(split.threshold())); } }
/**
 * Visits this element and then, unless the visitor declines, its extensions.
 *
 * @param visitor the visitor to dispatch to
 * @return {@code TERMINATE} if the visitor asked to stop; {@code CONTINUE} otherwise
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);
    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);
        // status is necessarily CONTINUE here, so only the presence of extensions matters
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }
        visitor.popParent();
    }
    // Any status other than TERMINATE is reported to the caller as CONTINUE
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Appends the given extensions, in order, to this element's extension list.
 *
 * @param extensions extensions to append
 * @return this instance, for call chaining
 */
@Override
public True addExtensions(Extension... extensions) {
    // Fetch the target list once, before touching the argument array
    java.util.List<Extension> target = getExtensions();
    for (Extension extension : extensions) {
        target.add(extension);
    }
    return this;
}
/**
 * Lets {@code visitor} inspect this element and, when allowed, descend into its extensions.
 *
 * @param visitor the visitor to dispatch to
 * @return {@code TERMINATE} if traversal should stop; {@code CONTINUE} otherwise
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);
    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);
        // Inside this branch status is CONTINUE, so the extensions are walked whenever present
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }
        visitor.popParent();
    }
    boolean terminate = (status == VisitorAction.TERMINATE);
    return terminate ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Adds each supplied extension to the end of this element's extension list.
 *
 * @param extensions extensions to add, in order
 * @return this instance, for call chaining
 */
@Override
public True addExtensions(Extension... extensions) {
    java.util.List<Extension> destination = getExtensions();
    for (Extension item : extensions) {
        destination.add(item);
    }
    return this;
}
// Add a segment for the tree identified by treeID. Its predicate is the constant True,
// so the segment's model is not gated by any condition. Every such segment gets the
// same fixed weight of 1.0.
segments.add(new Segment()
    .setId(Integer.toString(treeID))
    .setPredicate(new True())
    .setModel(treeModel)
    .setWeight(1.0)); // No weights in MLlib impl now
Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True()); Node left = new Node().setId("r-").setRecordCount(halfCount).setPredicate(new True()); left.addScoreDistributions(new ScoreDistribution("apple", halfCount)); Node right = new Node().setId("r+").setRecordCount(halfCount) segments.add(new Segment() .setId(Integer.toString(i)) .setPredicate(new True()) .setModel(treeModel) .setWeight(1.0));
/**
 * Factory method for the {@link True } element.
 *
 * @return a newly constructed {@link True } instance
 */
public True createTrue() {
    final True instance = new True();
    return instance;
}
/**
 * Builds a fresh {@link True } element.
 *
 * @return a new {@link True } object
 */
public True createTrue() {
    final True created = new True();
    return created;
}
/**
 * Builds a {@link Segmentation} containing one {@link Segment} per model.
 *
 * <p>Each segment's id is the model's 1-based position in {@code models}, and its predicate
 * is {@code True}. A weight is attached only when one was supplied for that position and
 * {@code ValueUtil.isOne} does not consider it equal to one.
 *
 * @param multipleModelMethod how the segment results are combined
 * @param models the segment models, in order
 * @param weights per-model weights aligned with {@code models}, or {@code null} for none;
 *     individual elements may be {@code null}
 * @return the assembled segmentation
 * @throws IllegalArgumentException if {@code weights} is non-null and its size differs from
 *     the size of {@code models}
 */
static public Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models, List<? extends Number> weights){

    if((weights != null) && (models.size() != weights.size())){
        throw new IllegalArgumentException("Expected " + models.size() + " weights (one per model), got " + weights.size());
    }

    List<Segment> segments = new ArrayList<>(models.size());

    for(int i = 0; i < models.size(); i++){
        Model model = models.get(i);
        Number weight = (weights != null ? weights.get(i) : null);

        Segment segment = new Segment()
            .setId(String.valueOf(i + 1))
            .setPredicate(new True())
            .setModel(model);

        // A weight that is absent or equal to one is left unset on the segment
        if(weight != null && !ValueUtil.isOne(weight)){
            segment.setWeight(ValueUtil.asDouble(weight));
        }

        segments.add(segment);
    }

    return new Segmentation(multipleModelMethod, segments);
}
secondChild.setPredicate(new True()); } else
secondChild.setPredicate(new True()); } else
secondChild.setPredicate(new True()); } else
/**
 * Decodes a compressed tree into a regression {@link TreeModel}.
 *
 * <p>Node ids come from a counter starting at 1. The root node's predicate is {@code True}.
 * Missing values are routed using the DEFAULT_CHILD strategy.
 *
 * @param compressedTree the serialized tree bytes
 * @param predicateManager passed through to node encoding
 * @param schema the feature schema used while encoding nodes
 * @return the encoded tree model
 */
static public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema){
    Label doubleLabel = new ContinuousLabel(null, DataType.DOUBLE);

    AtomicInteger nodeIds = new AtomicInteger(1);
    ByteBufferWrapper reader = new ByteBufferWrapper(compressedTree);

    Node rootNode = encodeNode(new True(), nodeIds, compressedTree, reader, predicateManager, new CategoryManager(), schema);

    return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(doubleLabel), rootNode)
        .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
}
/**
 * Builds a binary-split {@link TreeModel} from an R tree object.
 *
 * @param tree the R representation of the tree
 * @param schema supplies the label and features
 * @return a tree model using this converter's mining function
 */
private TreeModel encodeTreeModel(RGenericVector tree, Schema schema){
    // The root always applies, hence the True predicate
    Node rootNode = encodeNode(new True(), tree, schema);

    return new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
/**
 * Turns {@code node} into a catch-all by appending a {@code True} alternative to a
 * SURROGATE compound predicate.
 *
 * <p>A SURROGATE compound predicate evaluates its children in order and uses the first
 * result that is not unknown, so a trailing {@code True} acts as the final fallback. An
 * existing compound predicate is reused only when it is already SURROGATE. Any other
 * predicate, including AND/OR/XOR compounds, is wrapped. Appending {@code True} directly
 * to an OR would make it always true, and appending it to an AND would have no effect.
 *
 * @param node the node whose predicate is extended in place
 */
private void makeDefault(Node node){
    Predicate predicate = node.getPredicate();

    CompoundPredicate surrogatePredicate;

    if((predicate instanceof CompoundPredicate) && (((CompoundPredicate)predicate).getBooleanOperator() == CompoundPredicate.BooleanOperator.SURROGATE)){
        surrogatePredicate = (CompoundPredicate)predicate;
    } else {
        surrogatePredicate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE)
            .addPredicates(predicate);

        node.setPredicate(surrogatePredicate);
    }

    surrogatePredicate.addPredicates(new True());
}
secondChild.setPredicate(new True()); } else
/**
 * Converts a root tree node into a PMML node whose predicate is {@code True}.
 *
 * <p>The left and right subtrees are converted (left first) using this node's split, and
 * attached as children in that order.
 *
 * @param node the source node
 * @return the PMML node with exactly two children
 */
public org.dmg.pmml.Node convert(Node node) {
  org.dmg.pmml.Node root = new org.dmg.pmml.Node();
  root.setId(String.valueOf(node.getId()));
  root.setDefaultChild(null);
  root.setPredicate(new True());
  root.setEmbeddedModel(null);

  List<org.dmg.pmml.Node> children = root.getNodes();
  children.add(convert(node.getLeft(), true, node.getSplit()));
  children.add(convert(node.getRight(), false, node.getSplit()));

  return root;
}
.setPredicate(new True()) .setScore(ValueUtil.formatValue(classes.get(index)));