private Predicate buildPredicate(Split split, CategoricalValueEncodings categoricalValueEncodings) { if (split == null) { // Left child always applies, but is evaluated second return new True(); } int featureIndex = inputSchema.predictorToFeatureIndex(split.feature()); FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(featureIndex)); if (split.featureType().equals(FeatureType.Categorical())) { // Note that categories in MLlib model select the *left* child but the // convention here will be that the predicate selects the *right* child // So the predicate will evaluate "not in" this set // More ugly casting @SuppressWarnings("unchecked") Collection<Double> javaCategories = (Collection<Double>) (Collection<?>) JavaConversions.seqAsJavaList(split.categories()); Set<Integer> negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet()); Map<Integer,String> encodingToValue = categoricalValueEncodings.getEncodingValueMap(featureIndex); List<String> negativeValues = negativeEncodings.stream().map(encodingToValue::get).collect(Collectors.toList()); String joinedValues = TextUtils.joinPMMLDelimited(negativeValues); return new SimpleSetPredicate(fieldName, SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, joinedValues)); } else { // For MLlib, left means <= threshold, so right means > return new SimplePredicate(fieldName, SimplePredicate.Operator.GREATER_THAN) .setValue(Double.toString(split.threshold())); } }
SimpleSetPredicate.BooleanOperator operator = simpleSetPredicate.getBooleanOperator(); Preconditions.checkArgument( operator == SimpleSetPredicate.BooleanOperator.IS_IN || operator == SimpleSetPredicate.BooleanOperator.IS_NOT_IN); int featureNumber = featureNames.indexOf(simpleSetPredicate.getField().getValue()); Map<String,Integer> valueEncodingMap = categoricalValueEncodings.getValueEncodingMap(featureNumber); String[] categories = TextUtils.parseDelimited(simpleSetPredicate.getArray().getValue(), ' '); BitSet activeCategories = new BitSet(valueEncodingMap.size()); if (operator == SimpleSetPredicate.BooleanOperator.IS_IN) {
SimpleSetPredicate p = new SimpleSetPredicate(); Set<Short> childCategories = split.getLeftOrRightCategories(); p.setField(new FieldName(CommonUtils.getSimpleColumnName(columnConfig.getColumnName()))); StringBuilder arrayStr = new StringBuilder(); List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum()); p.setArray(array); if(isLeft) { if(split.isLeft()) { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isIn")); } else { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isNotIn")); p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isNotIn")); } else { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isIn"));
/**
 * Lets {@code visitor} visit this predicate and then, if it asked to continue, its extensions
 * followed by its array.
 *
 * @param visitor the visitor to apply
 * @return {@code TERMINATE} if any visit requested termination, otherwise {@code CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);
    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);
        // status is known to be CONTINUE here, so only the presence of extensions matters
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }
        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor, getArray());
        }
        visitor.popParent();
    }
    // Any outcome other than TERMINATE (e.g. SKIP) is reported to the caller as CONTINUE
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Processes the field referenced by the predicate, then delegates to the default visit.
 *
 * @param simpleSetPredicate the predicate being visited
 * @return the action returned by the superclass visit
 */
@Override
public VisitorAction visit(SimpleSetPredicate simpleSetPredicate){
    process(simpleSetPredicate.getField());
    VisitorAction result = super.visit(simpleSetPredicate);
    return result;
}
/**
 * Adds the given extensions to the collection returned by {@code getExtensions()}.
 *
 * @param extensions the extensions to add
 * @return this predicate, to allow call chaining
 */
@Override
public SimpleSetPredicate addExtensions(Extension... extensions) {
    java.util.List<Extension> additions = Arrays.asList(extensions);
    getExtensions().addAll(additions);
    return this;
}
/**
 * Lets {@code visitor} visit this predicate and then, if it asked to continue, its extensions
 * followed by its array.
 *
 * @param visitor the visitor to apply
 * @return {@code TERMINATE} if any visit requested termination, otherwise {@code CONTINUE}
 */
@Override
public VisitorAction accept(Visitor visitor) {
    VisitorAction status = visitor.visit(this);
    if (status == VisitorAction.CONTINUE) {
        visitor.pushParent(this);
        // status is known to be CONTINUE here, so only the presence of extensions matters
        if (hasExtensions()) {
            status = PMMLObject.traverse(visitor, getExtensions());
        }
        if (status == VisitorAction.CONTINUE) {
            status = PMMLObject.traverse(visitor, getArray());
        }
        visitor.popParent();
    }
    // Any outcome other than TERMINATE (e.g. SKIP) is reported to the caller as CONTINUE
    return (status == VisitorAction.TERMINATE) ? VisitorAction.TERMINATE : VisitorAction.CONTINUE;
}
/**
 * Processes the field referenced by the predicate, then delegates to the default visit.
 *
 * @param simpleSetPredicate the predicate being visited
 * @return the action returned by the superclass visit
 */
@Override
public VisitorAction visit(SimpleSetPredicate simpleSetPredicate){
    process(simpleSetPredicate.getField());
    VisitorAction result = super.visit(simpleSetPredicate);
    return result;
}
/**
 * Adds the given extensions to the collection returned by {@code getExtensions()}.
 *
 * @param extensions the extensions to add
 * @return this predicate, to allow call chaining
 */
@Override
public SimpleSetPredicate addExtensions(Extension... extensions) {
    java.util.List<Extension> additions = Arrays.asList(extensions);
    getExtensions().addAll(additions);
    return this;
}
/**
 * Builds a key from the predicate's field, boolean operator and array content
 * (as returned by {@code ArrayUtil.getContent}).
 */
@Override
public ElementKey createKey(SimpleSetPredicate simpleSetPredicate){
    Array array = simpleSetPredicate.getArray();
    Object[] keyParts = new Object[3];
    keyParts[0] = simpleSetPredicate.getField();
    keyParts[1] = simpleSetPredicate.getBooleanOperator();
    keyParts[2] = ArrayUtil.getContent(array);
    return new ElementKey(keyParts);
}
};
left.addScoreDistributions(new ScoreDistribution("apple", halfCount)); Node right = new Node().setId("r+").setRecordCount(halfCount) .setPredicate(new SimpleSetPredicate(FieldName.create("color"), SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, "red")));
/**
 * Builds a key from the predicate's field, boolean operator and array content
 * (as returned by {@code ArrayUtil.getContent}).
 */
@Override
public ElementKey createKey(SimpleSetPredicate simpleSetPredicate){
    Array array = simpleSetPredicate.getArray();
    Object[] keyParts = new Object[3];
    keyParts[0] = simpleSetPredicate.getField();
    keyParts[1] = simpleSetPredicate.getBooleanOperator();
    keyParts[2] = ArrayUtil.getContent(array);
    return new ElementKey(keyParts);
}
};
/**
 * Creates a {@link SimpleSetPredicate} from its three components.
 *
 * @param field the field the predicate tests
 * @param booleanOperator whether membership ({@code isIn}) or non-membership ({@code isNotIn}) is tested
 * @param array the set of values
 * @return the new predicate
 */
private static SimpleSetPredicate createSimpleSetPredicate(FieldName field, SimpleSetPredicate.BooleanOperator booleanOperator, Array array){
    return new SimpleSetPredicate(field, booleanOperator, array);
}
/**
 * Builds a key from the predicate's field, boolean operator and array content
 * (as returned by {@code ArrayUtil.getContent}).
 */
@Override
public ElementKey createKey(SimpleSetPredicate simpleSetPredicate){
    Array array = simpleSetPredicate.getArray();
    Object[] keyParts = new Object[3];
    keyParts[0] = simpleSetPredicate.getField();
    keyParts[1] = simpleSetPredicate.getBooleanOperator();
    keyParts[2] = ArrayUtil.getContent(array);
    return new ElementKey(keyParts);
}
};
/**
 * Creates a new, empty {@link SimpleSetPredicate} using its no-argument constructor.
 *
 * @return a freshly constructed {@link SimpleSetPredicate}
 */
public SimpleSetPredicate createSimpleSetPredicate() {
    SimpleSetPredicate instance = new SimpleSetPredicate();
    return instance;
}
/**
 * Simplifies a {@link SimpleSetPredicate} whose array holds exactly one token into the
 * equivalent {@link SimplePredicate}. {@code isIn} becomes {@code equal} and {@code isNotIn}
 * becomes {@code notEqual}.
 *
 * <p>The input is returned unchanged in these cases:
 * <ul>
 *   <li>the array parses into zero or several tokens;</li>
 *   <li>the Array element, its value or the boolean operator is absent;</li>
 *   <li>the operator is neither {@code isIn} nor {@code isNotIn}.</li>
 * </ul>
 * Returning incomplete predicates unchanged (instead of throwing a NullPointerException here)
 * leaves any validation to whatever consumes the predicate later.
 *
 * @param simpleSetPredicate the predicate to simplify
 * @return an equivalent {@link SimplePredicate}, or {@code simpleSetPredicate} itself
 */
private Predicate transform(SimpleSetPredicate simpleSetPredicate){
    Array array = simpleSetPredicate.getArray();
    SimpleSetPredicate.BooleanOperator booleanOperator = simpleSetPredicate.getBooleanOperator();
    // Previously a missing array (array.getValue()) or operator (switch on null) caused an NPE
    if(array == null || booleanOperator == null){
        return simpleSetPredicate;
    }
    String value = array.getValue();
    if(value == null){
        return simpleSetPredicate;
    }
    List<String> tokens = ArrayUtil.parse(value, true);
    if(tokens.size() != 1){
        return simpleSetPredicate;
    }
    String token = tokens.get(0);
    switch(booleanOperator){
        case IS_IN:
            return createSimplePredicate(simpleSetPredicate.getField(), SimplePredicate.Operator.EQUAL, token);
        case IS_NOT_IN:
            return createSimplePredicate(simpleSetPredicate.getField(), SimplePredicate.Operator.NOT_EQUAL, token);
        default:
            return simpleSetPredicate;
    }
}
/**
 * Creates a new, empty {@link SimpleSetPredicate} using its no-argument constructor.
 *
 * @return a freshly constructed {@link SimpleSetPredicate}
 */
public SimpleSetPredicate createSimpleSetPredicate() {
    SimpleSetPredicate instance = new SimpleSetPredicate();
    return instance;
}
/**
 * Evaluates a {@code SimpleSetPredicate} against the given context.
 *
 * @param simpleSetPredicate the predicate to evaluate
 * @param context supplies the value of the referenced field
 * @return {@code null} when the field value equals {@code FieldValues.MISSING_VALUE}; otherwise
 *     whether the value is ({@code IS_IN}) or is not ({@code IS_NOT_IN}) in the predicate's set
 * @throws MissingAttributeException if the field or boolean operator attribute is absent
 * @throws MissingElementException if the Array element is absent (checked only for non-missing values)
 * @throws UnsupportedAttributeException for any other boolean operator
 */
public static Boolean evaluateSimpleSetPredicate(SimpleSetPredicate simpleSetPredicate, EvaluationContext context){
    FieldName fieldName = simpleSetPredicate.getField();
    if(fieldName == null){
        throw new MissingAttributeException(simpleSetPredicate, PMMLAttributes.SIMPLESETPREDICATE_FIELD);
    }
    SimpleSetPredicate.BooleanOperator operator = simpleSetPredicate.getBooleanOperator();
    if(operator == null){
        throw new MissingAttributeException(simpleSetPredicate, PMMLAttributes.SIMPLESETPREDICATE_BOOLEANOPERATOR);
    }
    FieldValue fieldValue = context.evaluate(fieldName);
    if(Objects.equals(FieldValues.MISSING_VALUE, fieldValue)){
        return null;
    }
    // Only the presence of the array is verified here; membership itself is decided by FieldValue.isIn
    if(simpleSetPredicate.getArray() == null){
        throw new MissingElementException(simpleSetPredicate, PMMLElements.SIMPLESETPREDICATE_ARRAY);
    }
    if(operator == SimpleSetPredicate.BooleanOperator.IS_IN){
        return fieldValue.isIn(simpleSetPredicate);
    }
    if(operator == SimpleSetPredicate.BooleanOperator.IS_NOT_IN){
        return !fieldValue.isIn(simpleSetPredicate);
    }
    throw new UnsupportedAttributeException(simpleSetPredicate, operator);
}
.setPredicate(new SimpleSetPredicate());
SimpleSetPredicate.BooleanOperator operator = simpleSetPredicate.getBooleanOperator(); Preconditions.checkArgument( operator == SimpleSetPredicate.BooleanOperator.IS_IN || operator == SimpleSetPredicate.BooleanOperator.IS_NOT_IN); int featureNumber = featureNames.indexOf(simpleSetPredicate.getField().getValue()); Map<String,Integer> valueEncodingMap = categoricalValueEncodings.getValueEncodingMap(featureNumber); String[] categories = TextUtils.parseDelimited(simpleSetPredicate.getArray().getValue(), ' '); BitSet activeCategories = new BitSet(valueEncodingMap.size()); if (operator == SimpleSetPredicate.BooleanOperator.IS_IN) {