// Verifies that classificationTask agrees with inputSchema.isClassification() (IllegalStateException otherwise).
// NOTE(review): a Node `root` with id "r" is created here but is not used in the visible code; all later
// statements configure `modelNode`, whose declaration is not visible — confirm which node is intended.
// modelNode receives: predicate, record count ((double) nodeCount), one score distribution, and a String score
// produced by Double.toString(targetEncodedValue). Two children are attached, with ids formed from the parent id
// plus '+' and '-'. defaultChild points at the '+' child when defaultRight is true, otherwise at the '-' child.
Preconditions.checkState(classificationTask == inputSchema.isClassification()); Node root = new Node(); root.setId("r"); modelNode.setPredicate(predicate); modelNode.setRecordCount((double) nodeCount); modelNode.addScoreDistributions(distribution); modelNode.setScore(Double.toString(targetEncodedValue)); Node positiveModelNode = new Node().setId(modelNode.getId() + '+'); Node negativeModelNode = new Node().setId(modelNode.getId() + '-'); modelNode.addNodes(positiveModelNode, negativeModelNode); modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId());
String id = root.getId(); List<Node> children = root.getNodes(); if (children.isEmpty()) { Collection<ScoreDistribution> scoreDistributions = root.getScoreDistributions(); Prediction prediction; if (scoreDistributions != null && !scoreDistributions.isEmpty()) { prediction = new NumericPrediction(Double.parseDouble(root.getScore()), (int) Math.round(root.getRecordCount())); Node negativeLeftChild; Node positiveRightChild; if (child1.getPredicate() instanceof True) { negativeLeftChild = child1; positiveRightChild = child2; } else { Preconditions.checkArgument(child2.getPredicate() instanceof True); negativeLeftChild = child2; positiveRightChild = child1; Predicate predicate = positiveRightChild.getPredicate(); boolean defaultDecision = positiveRightChild.getId().equals(root.getDefaultChild());
private static void checkNode(Node node) { assertNotNull(node.getId()); List<ScoreDistribution> scoreDists = node.getScoreDistributions(); int numDists = scoreDists.size(); if (numDists == 0) { List<Node> children = node.getNodes(); assertEquals(2, children.size()); Node rightChild = children.get(0); Node leftChild = children.get(1); assertInstanceOf(leftChild.getPredicate(), True.class); assertEquals(node.getRecordCount().doubleValue(), leftChild.getRecordCount() + rightChild.getRecordCount()); assertEquals(node.getId() + "+", rightChild.getId()); assertEquals(node.getId() + "-", leftChild.getId()); checkNode(rightChild); checkNode(leftChild);
@Override public void enterNode(Node node){ String id = node.getId(); Object score = node.getScore(); String defaultChild = node.getDefaultChild(); if(node.hasNodes()){ List<Node> children = node.getNodes(); Node secondChild = children.get(1); if((defaultChild).equals(firstChild.getId())){ children = swapChildren(node); } else if((defaultChild).equals(secondChild.getId())){ node.setDefaultChild(null); secondChild.setPredicate(new True()); } else node.setId(null);
value.setId(node.getId()); value.setScore(node.getScore()); value.setRecordCount(node.getRecordCount()); value.setDefaultChild(node.getDefaultChild()); if(node.hasExtensions()){ (value.getExtensions()).addAll(node.getExtensions()); value.setPredicate(node.getPredicate()); value.setPartition(node.getPartition()); if(node.hasScoreDistributions()){ (value.getScoreDistributions()).addAll(node.getScoreDistributions()); if(node.hasNodes()){ (value.getNodes()).addAll(node.getNodes()); value.setEmbeddedModel(node.getEmbeddedModel());
@Override public void enterNode(Node node){ String id = node.getId(); Object score = node.getScore(); if(node.hasNodes()){ List<Node> children = node.getNodes(); Node secondChild = children.get(1); Predicate firstPredicate = firstChild.getPredicate(); Predicate secondPredicate = secondChild.getPredicate(); secondChild.setPredicate(new True()); } else node.setId(null);
static public ComplexNode toComplexNode(Node node){ ComplexNode result = new ComplexNode() .setId(node.getId()) .setScore(node.getScore()) .setRecordCount(node.getRecordCount()) .setDefaultChild(node.getDefaultChild()) .setPredicate(node.getPredicate()); if(node.hasNodes()){ (result.getNodes()).addAll(node.getNodes()); } // End if if(node.hasScoreDistributions()){ (result.getScoreDistributions()).addAll(node.getScoreDistributions()); } return result; } }
.setId(id) .setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId()) .setPredicate(predicate) .addNodes(leftChild, rightChild);
// Builds a small two-leaf tree rooted at "r" (predicate True, record count dummyCount):
//   "r+" — predicate foo > 3.14 (SimplePredicate GREATER_THAN, value "3.14"), halfCount records, score "2.0"
//   "r-" — predicate True, halfCount records, score "-2.0"
// The children are added in the order (right, left), so the True-predicate "r-" node is the last child.
Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True()); Node left = new Node() .setId("r-") .setRecordCount(halfCount) .setPredicate(new True()) .setScore("-2.0"); Node right = new Node().setId("r+").setRecordCount(halfCount) .setPredicate(new SimplePredicate(FieldName.create("foo"), SimplePredicate.Operator.GREATER_THAN).setValue("3.14")) .setScore("2.0"); rootNode.addNodes(right, left);
/**
 * Recursively encodes the internal node at {@code index} (and its subtree) into {@code parent}.
 *
 * <p>The PMML id is the internal index shifted by one. A leaf contributes only its value,
 * formatted as a float score. A split contributes two children whose predicates come from
 * {@code encodePredicate(feature, node, true/false)}.
 *
 * @param parent the PMML node to populate
 * @param index position of the source node in {@code allNodes}
 * @param schema schema used to resolve the split feature
 */
private void encodeNode(org.dmg.pmml.tree.Node parent, int index, Schema schema){
	parent.setId(String.valueOf(index + 1));

	Node source = allNodes.get(index);

	if(source.isLeaf()){
		parent.setScore(ValueUtil.formatValue((float)source.getValue()));

		return;
	}

	int featureIndex = source.getFeatureIndex();
	Feature splitFeature = schema.getFeature(featureIndex);

	org.dmg.pmml.tree.Node leftChild = new org.dmg.pmml.tree.Node()
		.setPredicate(encodePredicate(splitFeature, source, true));
	// Recursing first assigns the child's id, which setDefaultChild reads below
	encodeNode(leftChild, source.getLeftChild().getId(), schema);

	org.dmg.pmml.tree.Node rightChild = new org.dmg.pmml.tree.Node()
		.setPredicate(encodePredicate(splitFeature, source, false));
	encodeNode(rightChild, source.getRightChild().getId(), schema);

	parent.addNodes(leftChild, rightChild);

	// The right child is always designated as the default child
	parent.setDefaultChild(rightChild.getId());
}
// Builds a small two-leaf tree rooted at "r" (predicate True, record count dummyCount):
//   "r+" — predicate color IS_NOT_IN {"red"}, halfCount records, one ScoreDistribution ("banana", halfCount)
//   "r-" — predicate True, halfCount records, one ScoreDistribution ("apple", halfCount)
// The children are added in the order (right, left), so the True-predicate "r-" node is the last child.
Node rootNode = new Node().setId("r").setRecordCount(dummyCount).setPredicate(new True()); Node left = new Node().setId("r-").setRecordCount(halfCount).setPredicate(new True()); left.addScoreDistributions(new ScoreDistribution("apple", halfCount)); Node right = new Node().setId("r+").setRecordCount(halfCount) .setPredicate(new SimpleSetPredicate(FieldName.create("color"), SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, "red"))); right.addScoreDistributions(new ScoreDistribution("banana", halfCount)); rootNode.addNodes(right, left);
.setId(id) .setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId()) .setRecordCount((double)this.internal_count_[index]) .setPredicate(predicate) .addNodes(leftChild, rightChild); .setScore(this.leaf_value_[index]) .setRecordCount((double)this.leaf_count_[index]) .setPredicate(predicate);
@Override public Node unmarshal(ComplexNode value){ if(value.getRecordCount() != null){ return value; } // End if if(value.hasExtensions() || (value.getPartition() != null) || value.hasScoreDistributions() || (value.getEmbeddedModel() != null)){ return value; } Node node; if(value.hasNodes()){ node = new BranchNode() .setId(value.getId()) .setDefaultChild(value.getDefaultChild()); (node.getNodes()).addAll(value.getNodes()); } else { node = new LeafNode() .setId(value.getId()); } node .setScore(value.getScore()) .setPredicate(value.getPredicate()); return node; }
private Trail handleDefaultChild(Trail trail, Node node, EvaluationContext context){ // "The defaultChild missing value strategy requires the presence of the defaultChild attribute in every non-leaf Node" String defaultChild = node.getDefaultChild(); if(defaultChild == null){ throw new MissingAttributeException(node, PMMLAttributes.NODE_DEFAULTCHILD); } trail.addMissingLevel(); List<Node> children = node.getNodes(); for(int i = 0, max = children.size(); i < max; i++){ Node child = children.get(i); String id = child.getId(); if(id != null && (id).equals(defaultChild)){ // The predicate of the referenced Node is not evaluated return handleTrue(trail, child, context); } } // "Only Nodes which are immediate children of the respective Node can be referenced" throw new InvalidAttributeException(node, PMMLAttributes.NODE_DEFAULTCHILD, defaultChild); }
/**
 * Rewrites the score of every scored node as (node depth + average path length).
 *
 * <p>Depth is the number of leading {@code Node} entries in {@link #getParents()}, counted
 * until the first non-Node entry. The average path length is computed from
 * {@code nodeSamples}, indexed by the node's id parsed as an int. When {@code corrected}
 * is set, the corrected variant is used.
 */
@Override
public VisitorAction visit(Node node){

	if(node.getScore() == null){
		return super.visit(node);
	}

	int depth = 0;
	for(PMMLObject ancestor : getParents()){
		if(!(ancestor instanceof Node)){
			break;
		}

		depth++;
	}

	int sampleIndex = Integer.parseInt(node.getId());
	double sampleCount = this.nodeSamples[sampleIndex];

	double pathLength;
	if(corrected){
		pathLength = correctedAveragePathLength(sampleCount);
	} else {
		pathLength = averagePathLength(sampleCount);
	}

	// int + double widens to double, so the stored score matches a double-valued depth sum
	node.setScore(depth + pathLength);

	return super.visit(node);
}
};
static private Iterator<Node> getChildren(Node node){ Predicate predicate = node.getPredicate(); if(node.hasNodes()){ List<Node> children = node.getNodes(); Predicate childPredicate = child.getPredicate();
.setPredicate(new True()); List<Node> nodes = node1a.getNodes(); assertEquals(node1a.getId(), jaxbNode1a.getId()); List<Node> jaxbNodes = jaxbNode1a.getNodes(); assertEquals(node2a.getId(), jaxbNode2a.getId()); assertTrue(jaxbNode2a.hasExtensions()); assertEquals(node2b.getId(), jaxbNode2b.getId());
@Override public VisitorAction visit(Node node){ MathContext mathContext = this.mathContext; if(mathContext != null && node.hasScore()){ Object score = node.getScore(); if(score instanceof String){ String stringScore = (String)score; try { switch(mathContext){ case DOUBLE: node.setScore(Double.parseDouble(stringScore)); break; case FLOAT: node.setScore(Float.parseFloat(stringScore)); break; default: break; } } catch(NumberFormatException nfe){ // Ignored } } } return super.visit(node); } }
Node leftChild = new Node() .setPredicate(leftPredicate); Node rightChild = new Node() .setPredicate(rightPredicate); node.addNodes(leftChild, rightChild); } else