/**
 * Small demo: builds a 3-class confusion matrix, copies it, adds the original into the copy once
 * more, and prints the combined matrix as HTML and as CSV.
 */
public static void main(String[] args) {
    ConfusionMatrix<String> base = new ConfusionMatrix<>(Arrays.asList("a", "b", "c"));

    // (actual, predicted, count) triples, applied in exactly this order.
    String[] actual = {"a", "a", "b", "b", "b", "c", "c", "c"};
    String[] predicted = {"a", "b", "a", "b", "c", "a", "b", "c"};
    int[] counts = {88, 10, 14, 40, 6, 18, 10, 12};
    for (int row = 0; row < counts.length; row++) {
        base.add(actual[row], predicted[row], counts[row]);
    }

    // The copy constructor already adds base's counts; adding base again presumably doubles
    // every cell (assuming add(ConfusionMatrix) accumulates — confirm).
    ConfusionMatrix<String> combined = new ConfusionMatrix<>(base);
    combined.add(base);

    System.out.println(combined.toHTML());
    System.out.println(combined.toCSV());
}
}
/**
 * Returns how many examples had the given class as their actual (true) label, i.e. the actual
 * total for that class in the confusion matrix.
 *
 * @param clazz the class label
 * @return the number of times {@code clazz} was the actual label
 */
public int classCount(Integer clazz) {
    ConfusionMatrix<Integer> matrix = confusion();
    return matrix.getActualTotal(clazz);
}
/**
 * Returns the number of predictions counted as correct under the configured top-N setting.
 * For top N &lt;= 1 this is the sum of the confusion matrix diagonal (ordinary correct
 * predictions); otherwise the separately tracked top-N correct count is returned.
 *
 * @return number of correct top-N predictions
 */
public int getTopNCorrectCount() {
    // confusion() lazily creates an empty matrix if none exists yet.
    ConfusionMatrix<Integer> cm = confusion();
    if (topN > 1) {
        return topNCorrectCount;
    }
    // NOTE(review): treats the indices 0..n-1 as the class labels — confirm labels are contiguous.
    int n = cm.getClasses().size();
    int correct = 0;
    for (int idx = 0; idx < n; idx++) {
        correct += cm.getCount(idx, idx);
    }
    return correct;
}
/**
 * Copy constructor: builds a matrix over {@code other}'s class list and then adds all of
 * {@code other}'s counts into it.
 *
 * NOTE(review): whether the class list is copied or shared with {@code other} depends on the
 * {@code ConfusionMatrix(List)} constructor — confirm it does not alias the list.
 *
 * @param other the matrix whose classes and counts are copied
 */
public ConfusionMatrix(ConfusionMatrix<T> other) {
    this(other.getClasses());
    add(other);
}
builder.append(getCount(actual, predicted)); builder.append(","); builder.append(getActualTotal(actual)); builder.append("\n"); builder.append(getPredictedTotal(predicted)); builder.append(",");
/**
 * False negatives: examples that were incorrectly rejected for their class (by the usual
 * definition, the actual class was X but a different class was predicted). Not to be confused
 * with true negatives, which are the correctly rejected examples.
 *
 * @return the false-negative counts so far, passed through {@code convertToMap} over the number
 *     of classes currently known to the confusion matrix (presumably keyed by class index —
 *     confirm)
 */
public Map<Integer, Integer> falseNegatives() {
    return convertToMap(falseNegatives, confusion().getClasses().size());
}
/** * Accuracy: * (TP + TN) / (P + N) * * @return the accuracy of the guesses so far */ public double accuracy() { //Accuracy: sum the counts on the diagonal of the confusion matrix, divide by total int nClasses = confusion().getClasses().size(); int countCorrect = 0; for (int i = 0; i < nClasses; i++) { countCorrect += confusion().getCount(i, i); } return countCorrect / (double) getNumRowCounter(); }
confusion = new ConfusionMatrix<>(other.confusion); } else { if (other.confusion != null) confusion().add(other.confusion);
@Override public void serialize(ConfusionMatrix<Integer> cm, JsonGenerator gen, SerializerProvider provider) throws IOException, JsonProcessingException { List<Integer> classes = cm.getClasses(); Map<Integer, Multiset<Integer>> matrix = cm.getMatrix(); Map<Integer, int[][]> m2 = new LinkedHashMap<>(); for (Integer i : matrix.keySet()) { //i = Actual class Multiset<Integer> ms = matrix.get(i); int[][] arr = new int[2][ms.size()]; int used = 0; for (Integer j : ms.elementSet()) { int count = ms.count(j); arr[0][used] = j; //j = Predicted class arr[1][used] = count; //prediction count used++; } m2.put(i, arr); } gen.writeStartObject(); gen.writeObjectField("classes", classes); gen.writeObjectField("matrix", m2); gen.writeEndObject(); } }
/**
 * Records a single (actual, predicted) observation. This is shorthand for
 * {@code add(actual, predicted, 1)}.
 *
 * @param actual the true class
 * @param predicted the predicted class
 */
public synchronized void add(T actual, T predicted) {
    final int one = 1;
    this.add(actual, predicted, one);
}
/**
 * Returns the confusion matrix, creating an empty one the first time it is needed.
 *
 * @return the (never null) confusion matrix
 */
private ConfusionMatrix<Integer> confusion() {
    if (confusion == null) {
        confusion = new ConfusionMatrix<>();
    }
    return confusion;
}
/**
 * Counts how many times the classifier predicted {@code predicted}, summed over every actual
 * class in this matrix's class list.
 *
 * @param predicted the predicted class to total
 * @return the total number of predictions of that class
 */
public synchronized int getPredictedTotal(T predicted) {
    int sum = 0;
    for (T trueLabel : classes) {
        sum += getCount(trueLabel, predicted);
    }
    return sum;
}
builder.append(getCount(actual, predicted)); builder.append("</td>"); builder.append(getActualTotal(actual)); builder.append("</td>"); builder.append("</tr>\n"); for (T predicted : classes) { builder.append("<td class=\"count-element\">"); builder.append(getPredictedTotal(predicted)); builder.append("</td>");
/**
 * False positives: predictions of a class that turned out to be wrong (by the usual definition,
 * class X was predicted but the actual class was different).
 *
 * @return the false-positive counts so far, passed through {@code convertToMap} over the number
 *     of classes currently known to the confusion matrix
 */
public Map<Integer, Integer> falsePositives() {
    int nClasses = confusion().getClasses().size();
    return convertToMap(falsePositives, nClasses);
}
int nClasses = confusion().getClasses().size(); args[1] = labelsList.get(i); for (int j = 0; j < nClasses; j++) { args[j + 2] = confusion().getCount(i, j);
ConfusionMatrix<Integer> cm = new ConfusionMatrix<>(classes); int count = iterCnt.next().asInt(); cm.add(actualClass, predictedClass, count);
/**
 * Records one prediction in the confusion matrix (the matrix is created on first use).
 *
 * @param real the actual (true) class
 * @param guess the class predicted by the system
 */
public void addToConfusion(Integer real, Integer guess) {
    ConfusionMatrix<Integer> matrix = confusion();
    matrix.add(real, guess);
}
/**
 * Creates an evaluation with top-N set to 1 and an empty confusion matrix.
 */
public Evaluation() {
    confusion = new ConfusionMatrix<>();
    this.topN = 1;
}
/**
 * True negatives: examples that were correctly rejected for a class (by the usual definition,
 * the class was neither the actual nor the predicted one).
 *
 * @return the true-negative counts so far, passed through {@code convertToMap} over the number
 *     of classes currently known to the confusion matrix
 */
public Map<Integer, Integer> trueNegatives() {
    int nClasses = confusion().getClasses().size();
    return convertToMap(trueNegatives, nClasses);
}
StringBuilder builder = new StringBuilder().append("\n"); StringBuilder warnings = new StringBuilder(); List<Integer> classes = confusion().getClasses(); int count = confusion().getCount(clazz, clazz2); if (count != 0) { predicted = resolveLabelForClass(clazz2); builder.append(warnings); int nClasses = confusion().getClasses().size(); DecimalFormat df = new DecimalFormat("0.0000"); double acc = accuracy();