/**
 * Prints a summary of the most influential features of the best model found by
 * {@code learningAlgorithm}.
 *
 * <p>The best learner is fetched via {@code getBest().getPayload().getLearner()} and closed
 * before it is inspected. A random sample of at most 500 of the supplied files is then encoded
 * one at a time, and each encoded vector is fed to a {@link ModelDissector} together with the
 * trace dictionary shared with {@code encoder} and {@code bias}. Finally the top 100 weights
 * reported by the dissector are written to standard output as tab-separated lines.
 *
 * @param newsGroups        dictionary whose {@code values()} supply the names printed for each
 *                          weight's maximum-impact category
 * @param learningAlgorithm trained learner pool from which the best model is taken
 * @param files             candidate files to sample from; fewer than 500 files is allowed, in
 *                          which case all of them are used
 * @throws IOException declared by the method; presumably raised while reading/encoding a file
 *                     (confirm against {@code encodeFeatureVector})
 */
private static void dissect(Dictionary newsGroups,
                            AdaptiveLogisticRegression learningAlgorithm,
                            Iterable<File> files) throws IOException {
  final int maxSampleFiles = 500;
  final int summarySize = 100;

  CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
  // The model is closed before any weights are read.
  // NOTE(review): presumably close() finalizes pending updates rather than releasing the model;
  // confirm.
  model.close();

  Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
  ModelDissector md = new ModelDissector();

  // Both encoders record into the same dictionary that is handed to the dissector below.
  encoder.setTraceDictionary(traceDictionary);
  bias.setTraceDictionary(traceDictionary);

  // Cap the sample at the number of files actually available; an unconditional
  // subList(0, 500) throws IndexOutOfBoundsException for smaller inputs.
  List<File> shuffled = permute(files, rand);
  int sampleSize = Math.min(maxSampleFiles, shuffled.size());
  for (File file : shuffled.subList(0, sampleSize)) {
    // Reset the dictionary so each update only sees entries recorded for the current file.
    traceDictionary.clear();
    Vector v = encodeFeatureVector(file);
    md.update(v, traceDictionary, model);
  }

  List<String> ngNames = Lists.newArrayList(newsGroups.values());
  List<ModelDissector.Weight> weights = md.summary(summarySize);
  for (ModelDissector.Weight w : weights) {
    // NOTE(review): the "%.1f" specifiers line up with getCategory(..) and "%s" with
    // getWeight(..); verify against the return types of ModelDissector.Weight -- TODO confirm.
    // NOTE(review): the "+ 1" offset on getMaxImpact() is not explained by this method; confirm.
    System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
        w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
        w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
  }
}
/**
 * Checks that a {@code CrossFoldLearner} passed through {@code roundTrip} reports the same AUC
 * as the original, and that both keep reporting similar AUC values after further training.
 *
 * <p>After the round trip the original learner receives two more training passes and the
 * restored copy one; their AUCs are then compared with a tolerance of 0.02, and the original's
 * AUC is required to have improved over its pre-round-trip value.
 *
 * @throws IOException declared by the test; presumably propagated from the round trip
 */
@Test
public void crossFoldLearnerRoundTrip() throws IOException {
  CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
  train(learner, 100);
  CrossFoldLearner restored = roundTrip(learner, CrossFoldLearner.class);

  // Baseline: the trained learner must be reasonably accurate, report a stable AUC, and the
  // restored copy must agree with it almost exactly.
  double baselineAuc = learner.auc();
  assertTrue(baselineAuc > 0.85);
  assertEquals(baselineAuc, learner.auc(), 1.0e-6);
  assertEquals(baselineAuc, restored.auc(), 1.0e-6);

  // Continue training: two passes for the original, then one for the restored copy.
  for (int pass = 0; pass < 2; pass++) {
    train(learner, 100);
  }
  train(restored, 100);

  // Two consecutive reads of the original's AUC must agree within the loose tolerance.
  double firstRead = learner.auc();
  double secondRead = learner.auc();
  assertEquals(firstRead, secondRead, 0.02);

  // The original and the restored learner must still be close to each other.
  double originalAuc = learner.auc();
  double restoredAuc = restored.auc();
  assertEquals(originalAuc, restoredAuc, 0.02);

  // Extra training must have improved the original learner.
  double finalAuc = learner.auc();
  assertTrue(finalAuc > baselineAuc);

  learner.close();
  restored.close();
}