/**
 * Draws repeatedly from a restaurant built with {@code (100, 1)} and checks that the number of
 * tables equals the number of draws and that every table holds exactly one entry.
 *
 * NOTE(review): the second constructor argument is presumably the discount (per the test name);
 * confirm against ChineseRestaurant.
 */
@Test
public void testExtremeDiscount() {
  final int draws = 10000;
  ChineseRestaurant x = new ChineseRestaurant(100, 1);
  for (int i = 0; i < draws; i++) {
    // Only the restaurant's internal bookkeeping is verified below, so the returned
    // table index does not need to be collected.
    x.sample();
  }

  // one table per draw ...
  assertEquals(draws, x.size());
  // ... and each of those tables was hit exactly once
  for (int i = 0; i < draws; i++) {
    assertEquals(1, x.count(i));
  }
}
@Test public void testGrowth() { ChineseRestaurant s0 = new ChineseRestaurant(10, 0.0); ChineseRestaurant s5 = new ChineseRestaurant(10, 0.5); ChineseRestaurant s9 = new ChineseRestaurant(10, 0.9); Set<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0); assertEquals(predict5, Math.log(s5.size()), 1); assertEquals(predict9, Math.log(s9.size()), 1); double x = 10.5 * Math.log(i) - s0.size(); m5.viewRow(k).assign(new double[]{Math.log(s5.size()), Math.log(i), 1}); m9.viewRow(k).assign(new double[]{Math.log(s9.size()), Math.log(i), 1}); assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25); assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1); assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05); s0.sample(); s5.sample(); s9.sample(); i++;
@Test public void testDepth() { List<Integer> totals = Lists.newArrayList(); for (int i = 0; i < 1000; i++) { ChineseRestaurant x = new ChineseRestaurant(10); Multiset<Integer> counts = HashMultiset.create(); for (int j = 0; j < 100; j++) { counts.add(x.sample()); } List<Integer> tmp = Lists.newArrayList(); for (Integer k : counts.elementSet()) { tmp.add(counts.count(k)); } Collections.sort(tmp, Collections.reverseOrder()); while (totals.size() < tmp.size()) { totals.add(0); } int j = 0; for (Integer k : tmp) { totals.set(j, totals.get(j) + k); j++; } } // these are empirically derived values, not principled ones assertEquals(25000.0, (double) totals.get(0), 1000); assertEquals(24000.0, (double) totals.get(1), 1000); assertEquals(8000.0, (double) totals.get(2), 200); assertEquals(1000.0, (double) totals.get(15), 50); assertEquals(1000.0, (double) totals.get(20), 40); }
/**
 * Creates a sampler for a lumpy distribution, which behaves somewhat more like real data than
 * plain draws from a normal distribution do.
 *
 * @param dimension The dimension of the vectors to return.
 * @param radius    The size of the clusters we sample from.
 * @param alpha     Controls how quickly the number of clusters grows; the number of clusters
 *                  will be about alpha * log(samples).
 */
public LumpyData(int dimension, double radius, double alpha) {
  // source of cluster centers, built for the requested dimension
  this.centers = new MultiNormal(dimension);
  this.radius = radius;
  // cluster-selection process parameterized by alpha
  this.cluster = new ChineseRestaurant(alpha);
}