/**
 * Constructing from an explicit probability distribution should preserve
 * that distribution and report the encoding with the largest probability.
 */
@Test
public void testConstructFromProbability() {
  double[] distribution = { 0.0, 0.125, 0.375, 0.0, 0.5, 0.0 };
  CategoricalPrediction fromDistribution = new CategoricalPrediction(distribution);
  assertEquals(FeatureType.CATEGORICAL, fromDistribution.getFeatureType());
  // Index 4 holds the largest probability (0.5)
  assertEquals(4, fromDistribution.getMostProbableCategoryEncoding());
  assertArrayEquals(distribution, fromDistribution.getCategoryProbabilities());
}
/**
 * update(encoding, delta) should add to the stored per-category counts and
 * renormalize the derived probability distribution accordingly.
 */
@Test
public void testUpdate2() {
  CategoricalPrediction prediction = new CategoricalPrediction(new int[] { 0, 1, 3, 0, 4, 0 });
  prediction.update(0, 3);
  prediction.update(1, 9);
  // Category 0 gained 3 and category 1 gained 9; total count is now 20
  assertArrayEquals(new double[] { 3, 10, 3, 0, 4, 0 }, prediction.getCategoryCounts());
  assertArrayEquals(new double[] { 0.15, 0.5, 0.15, 0.0, 0.2, 0.0 }, prediction.getCategoryProbabilities());
}
/**
 * Combines several categorical predictions into one by averaging their
 * per-category probability distributions, weighted by each predictor's weight.
 *
 * @param predictions individual predictions to combine; must be non-empty
 * @param weights per-prediction weights, parallel to {@code predictions}
 * @return a prediction whose distribution is the normalized weighted sum
 */
private static Prediction voteOnCategoricalFeature(List<CategoricalPrediction> predictions,
                                                   double[] weights) {
  double[] summed = null;
  double weightSum = 0.0;
  int index = 0;
  for (CategoricalPrediction prediction : predictions) {
    double weight = weights[index++];
    weightSum += weight;
    double[] probabilities = prediction.getCategoryProbabilities();
    if (summed == null) {
      // Lazily size the accumulator from the first prediction's category count
      summed = new double[probabilities.length];
    }
    for (int category = 0; category < summed.length; category++) {
      summed[category] += weight * probabilities[category];
    }
  }
  // Fails fast if the predictions list was empty
  Objects.requireNonNull(summed, "No predictions?");
  // Normalize so the result is again a probability distribution
  for (int category = 0; category < summed.length; category++) {
    summed[category] /= weightSum;
  }
  return new CategoricalPrediction(summed);
}
/**
 * With the second predictor weighted 10x, its preferred category (0) should
 * dominate the vote even though the unweighted majority would pick 1.
 */
@Test
public void testCategoricalVoteWeighted() {
  List<CategoricalPrediction> votes = Arrays.asList(
      new CategoricalPrediction(new int[] { 0, 1, 2 }),
      new CategoricalPrediction(new int[] { 6, 2, 0 }),
      new CategoricalPrediction(new int[] { 0, 2, 0 }));
  CategoricalPrediction combined =
      (CategoricalPrediction) WeightedPrediction.voteOnFeature(votes, new double[] { 1.0, 10.0, 1.0 });
  assertEquals(FeatureType.CATEGORICAL, combined.getFeatureType());
  assertEquals(0, combined.getMostProbableCategoryEncoding());
}
/**
 * Renders the model's prediction for one input row as a String: the decoded
 * category value for classification problems, or the numeric score otherwise.
 */
@Override
public String predict(String[] example) {
  Prediction prediction = makePrediction(example);
  if (!inputSchema.isClassification()) {
    // Regression: just render the numeric prediction
    return Double.toString(((NumericPrediction) prediction).getPrediction());
  }
  // Classification: translate the encoded category back to its original value
  int targetIndex = inputSchema.getTargetFeatureIndex();
  Map<Integer,String> encodingToValue = encodings.getEncodingValueMap(targetIndex);
  int mostProbable = ((CategoricalPrediction) prediction).getMostProbableCategoryEncoding();
  return encodingToValue.get(mostProbable);
}
/**
 * Predictions built from different count vectors must not compare equal.
 */
@Test
public void testEquals() {
  CategoricalPrediction prediction = new CategoricalPrediction(new int[] { 0, 1, 3, 0, 4, 0 });
  CategoricalPrediction different = new CategoricalPrediction(new int[] { 1, 2, 4, 5, 6, 7 });
  assertNotEquals(prediction, different);
}
/**
 * Pins the current hash value for regression detection, and verifies the
 * hashCode/equals contract: two predictions built from the same counts must
 * produce the same hash code.
 */
@Test
public void testHashCode() {
  int[] counts = { 0, 1, 3, 0, 4, 0 };
  CategoricalPrediction prediction = new CategoricalPrediction(counts);
  // Regression pin: detects unintended changes to the hash implementation
  assertEquals(566115137, prediction.hashCode());
  // Contract check: equal objects must share a hash code, regardless of the
  // exact hashing algorithm in use
  assertEquals(prediction.hashCode(), new CategoricalPrediction(counts).hashCode());
}
/**
 * toString should render the probability distribution derived from counts.
 */
@Test
public void testToString() {
  CategoricalPrediction prediction = new CategoricalPrediction(new int[] { 0, 1, 3, 0, 4, 0 });
  assertEquals(":[0.0, 0.125, 0.375, 0.0, 0.5, 0.0]", prediction.toString());
}
/**
 * Records one training example by incrementing the count of its target's
 * encoded category by one.
 */
@Override
public void update(Example train) {
  CategoricalFeature targetFeature = (CategoricalFeature) train.getTarget();
  update(targetFeature.getEncoding(), 1);
}
/**
 * Returns, for a single classification input, the decoded name and predicted
 * probability of every target category.
 *
 * @param datum one comma-delimited input example
 * @return one (category value, probability) pair per target encoding
 * @throws OryxServingException if input is missing/empty or the model is not
 *  a classifier
 */
@GET
@Path("{datum}")
@Produces({MediaType.TEXT_PLAIN, "text/csv", MediaType.APPLICATION_JSON})
public List<IDValue> get(@PathParam("datum") String datum) throws OryxServingException {
  check(datum != null && !datum.isEmpty(), "Missing input data");
  RDFServingModel model = (RDFServingModel) getServingModel();
  InputSchema inputSchema = model.getInputSchema();
  check(inputSchema.isClassification(), "Only applicable for classification");
  Prediction prediction = model.makePrediction(TextUtils.parseDelimited(datum, ','));
  double[] categoryProbabilities = ((CategoricalPrediction) prediction).getCategoryProbabilities();
  // Translate each category's encoding back to its original string value
  CategoricalValueEncodings valueEncodings = model.getEncodings();
  Map<Integer,String> encodingToValue =
      valueEncodings.getEncodingValueMap(inputSchema.getTargetFeatureIndex());
  List<IDValue> result = new ArrayList<>(categoryProbabilities.length);
  for (int encoding = 0; encoding < categoryProbabilities.length; encoding++) {
    result.add(new IDValue(encodingToValue.get(encoding), categoryProbabilities[encoding]));
  }
  return result;
}
// Expected per-category counts after the split: the left child accumulated
// (2, 5) and the right child (3, 4) across the two target categories.
// NOTE(review): leftPrediction/rightPrediction are defined outside this
// chunk — presumably the predictions of a split's two children; confirm
// against the enclosing test method.
assertEquals(2, leftPrediction.getCategoryCounts()[0]); assertEquals(5, leftPrediction.getCategoryCounts()[1]); assertEquals(3, rightPrediction.getCategoryCounts()[0]); assertEquals(4, rightPrediction.getCategoryCounts()[1]);
/**
 * With equal weights, the combined vote should pick the category with the
 * highest average probability across all predictors (category 1 here).
 */
@Test
public void testCategoricalVote() {
  List<CategoricalPrediction> votes = Arrays.asList(
      new CategoricalPrediction(new int[] { 0, 1, 2 }),
      new CategoricalPrediction(new int[] { 6, 2, 0 }),
      new CategoricalPrediction(new int[] { 0, 2, 0 }));
  CategoricalPrediction combined =
      (CategoricalPrediction) WeightedPrediction.voteOnFeature(votes, new double[] { 1.0, 1.0, 1.0 });
  assertEquals(FeatureType.CATEGORICAL, combined.getFeatureType());
  assertEquals(1, combined.getMostProbableCategoryEncoding());
}
/**
 * Combines several categorical predictions into one by a weighted average of
 * their per-category probability distributions.
 *
 * @param predictions individual predictions to combine; must be non-empty
 * @param weights per-prediction weights, parallel to {@code predictions}
 * @return a prediction whose distribution is the normalized weighted sum
 */
private static Prediction voteOnCategoricalFeature(List<CategoricalPrediction> predictions, double[] weights) {
  double[] weightedProbabilities = null;
  double totalWeight = 0.0;
  for (int i = 0; i < predictions.size(); i++) {
    CategoricalPrediction vote = predictions.get(i);
    double weight = weights[i];
    totalWeight += weight;
    double[] categoryProbabilities = vote.getCategoryProbabilities();
    if (weightedProbabilities == null) {
      // Lazily size the accumulator from the first prediction's category count
      weightedProbabilities = new double[categoryProbabilities.length];
    }
    for (int j = 0; j < weightedProbabilities.length; j++) {
      weightedProbabilities[j] += categoryProbabilities[j] * weight;
    }
  }
  // Fails fast if the predictions list was empty
  Objects.requireNonNull(weightedProbabilities, "No predictions?");
  // Normalize so the result is again a probability distribution
  for (int j = 0; j < weightedProbabilities.length; j++) {
    weightedProbabilities[j] /= totalWeight;
  }
  return new CategoricalPrediction(weightedProbabilities);
}
/**
 * Computes classification accuracy: the fraction of examples whose most
 * probable predicted category matches the actual target encoding.
 * Returns 0.0 for an empty RDD.
 */
static double accuracy(DecisionForest forest, JavaRDD<Example> examples) {
  long total = examples.count();
  if (total == 0) {
    return 0.0;
  }
  long correct = examples.filter(example -> {
    int predicted = ((CategoricalPrediction) forest.predict(example)).getMostProbableCategoryEncoding();
    int actual = ((CategoricalFeature) example.getTarget()).getEncoding();
    return predicted == actual;
  }).count();
  return (double) correct / total;
}
/**
 * Builds a small, deterministic two-tree forest over a 3-feature schema for
 * use in tests: features 0 and 2 are categorical (2 is the target), and
 * feature 1 is numeric.
 */
public static RDFServingModel buildTestModel() {
  // Distinct values for the two categorical features (0 and the target, 2)
  Map<Integer,Collection<String>> distinctValues = new HashMap<>();
  distinctValues.put(0, Arrays.asList("A", "B", "C"));
  distinctValues.put(2, Arrays.asList("X", "Y", "Z"));
  CategoricalValueEncodings encodings = new CategoricalValueEncodings(distinctValues);
  // Tree 1: splits on categorical feature 0 with category encoding 1 active
  TerminalNode left1 = new TerminalNode("r-", new CategoricalPrediction(new int[] { 1, 2, 3 }));
  TerminalNode right1 = new TerminalNode("r+", new CategoricalPrediction(new int[] { 10, 30, 50 }));
  BitSet activeCategories = new BitSet(2);
  activeCategories.set(1);
  Decision decision1 = new CategoricalDecision(0, activeCategories, true);
  TreeNode root1 = new DecisionNode("r", decision1, left1, right1);
  // Tree 2: splits on numeric feature 1 at threshold -3.0
  TerminalNode left2 = new TerminalNode("r-", new CategoricalPrediction(new int[] { 100, 400, 900 }));
  TerminalNode right2 = new TerminalNode("r+", new CategoricalPrediction(new int[] { 1000, 10000, 100000 }));
  Decision decision2 = new NumericDecision(1, -3.0, false);
  TreeNode root2 = new DecisionNode("r", decision2, left2, right2);
  DecisionTree tree1 = new DecisionTree(root1);
  DecisionTree tree2 = new DecisionTree(root2);
  DecisionTree[] trees = { tree1, tree2 };
  // Tree 2 carries twice the voting weight of tree 1
  double[] weights = { 1.0, 2.0 };
  double[] featureImportances = { 0.1, 0.3 };
  DecisionForest forest = new DecisionForest(trees, weights, featureImportances);
  // Overlay the test schema onto the default config
  Map<String,Object> overlayConfig = new HashMap<>();
  overlayConfig.put("oryx.input-schema.num-features", 3);
  overlayConfig.put("oryx.input-schema.categorical-features", "[\"0\",\"2\"]");
  overlayConfig.put("oryx.input-schema.target-feature", "\"2\"");
  Config config = ConfigUtils.overlayOn(overlayConfig, ConfigUtils.getDefault());
  InputSchema inputSchema = new InputSchema(config);
  return new RDFServingModel(forest, encodings, inputSchema);
}
// Apply per-category count deltas from the update payload to the terminal
// node's prediction. NOTE(review): this fragment is cut from the middle of a
// larger method — the enclosing conditional and the rest of the else-branch
// are not visible in this chunk.
@SuppressWarnings("unchecked")
Map<String,Integer> counts = (Map<String,Integer>) update.get(2); // JSON map keys are always Strings
// Keys arrive as Strings from JSON; parse back to the int category encoding
counts.forEach((encoding, count) -> predictionToUpdate.update(Integer.parseInt(encoding), count)); } else {
// Locate the terminal node to update by its ID within the identified tree
TerminalNode nodeToUpdate = (TerminalNode) forest.getTrees()[treeID].findByID(nodeID);
@Test public void testUpdate() { int[] counts = { 0, 1, 3, 0, 4, 0 }; CategoricalPrediction prediction = new CategoricalPrediction(counts); Example example = new Example(CategoricalFeature.forEncoding(2)); // Yes, called twice prediction.update(example); prediction.update(example); assertEquals(2, prediction.getMostProbableCategoryEncoding()); counts[2] += 2; assertArrayEquals(toDoubles(counts), prediction.getCategoryCounts()); assertArrayEquals(new double[] {0.0, 0.1, 0.5, 0.0, 0.4, 0.0}, prediction.getCategoryProbabilities()); }
/**
 * Constructing from raw counts should expose the counts, the normalized
 * probabilities, and the encoding with the largest count.
 */
@Test
public void testConstruct() {
  int[] initialCounts = { 0, 1, 3, 0, 4, 0 };
  CategoricalPrediction fromCounts = new CategoricalPrediction(initialCounts);
  assertEquals(FeatureType.CATEGORICAL, fromCounts.getFeatureType());
  // Index 4 holds the largest count (4 of 8)
  assertEquals(4, fromCounts.getMostProbableCategoryEncoding());
  assertArrayEquals(toDoubles(initialCounts), fromCounts.getCategoryCounts());
  assertArrayEquals(new double[] { 0.0, 0.125, 0.375, 0.0, 0.5, 0.0 }, fromCounts.getCategoryProbabilities());
}
// The label is positive exactly when all three features equal 1; the model's
// most probable category should decode to that boolean's string form.
// NOTE(review): f1/f2/f3, targetEncoding and prediction come from enclosing
// code not visible in this chunk — presumably a loop over feature values.
boolean expectedPositive = f1 == 1 && f2 == 1 && f3 == 1; assertEquals(targetEncoding.get(Boolean.toString(expectedPositive)).intValue(), prediction.getMostProbableCategoryEncoding());