@Test public void testDepth() { List<Integer> totals = Lists.newArrayList(); for (int i = 0; i < 1000; i++) { ChineseRestaurant x = new ChineseRestaurant(10); Multiset<Integer> counts = HashMultiset.create(); for (int j = 0; j < 100; j++) { counts.add(x.sample()); } List<Integer> tmp = Lists.newArrayList(); for (Integer k : counts.elementSet()) { tmp.add(counts.count(k)); } Collections.sort(tmp, Collections.reverseOrder()); while (totals.size() < tmp.size()) { totals.add(0); } int j = 0; for (Integer k : tmp) { totals.set(j, totals.get(j) + k); j++; } } // these are empirically derived values, not principled ones assertEquals(25000.0, (double) totals.get(0), 1000); assertEquals(24000.0, (double) totals.get(1), 1000); assertEquals(8000.0, (double) totals.get(2), 200); assertEquals(1000.0, (double) totals.get(15), 50); assertEquals(1000.0, (double) totals.get(20), 40); }
/**
 * With discount = 1, every sample should open a new table, so after n samples the
 * restaurant holds exactly n singleton clusters.
 */
@Test
public void testExtremeDiscount() {
  ChineseRestaurant x = new ChineseRestaurant(100, 1);
  // Draw 10,000 samples; we only care about the state accumulated inside x.
  // (The original version also collected the draws into an unused Multiset,
  // which was dead code and has been removed.)
  for (int i = 0; i < 10000; i++) {
    x.sample();
  }
  assertEquals(10000, x.size());
  for (int i = 0; i < 10000; i++) {
    assertEquals(1, x.count(i));
  }
}
// Sets up CRP samplers sharing alpha = 10 but with discounts 0.0, 0.5 and 0.9;
// `splits` presumably defines breakpoints used later in this test — the method
// continues beyond this excerpt, so TODO(review): confirm against the full body.
@Test public void testGrowth() { ChineseRestaurant s0 = new ChineseRestaurant(10, 0.0); ChineseRestaurant s5 = new ChineseRestaurant(10, 0.5); ChineseRestaurant s9 = new ChineseRestaurant(10, 0.9); Set<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0);
/**
 * Samples from a lumpy distribution that acts a bit more like real data than just
 * sampling from a normal distribution.
 *
 * @param dimension The dimension of the vectors to return.
 * @param radius    The size of the clusters we sample from.
 * @param alpha     Controls the growth of the number of clusters. The number of
 *                  clusters will be about alpha * log(samples)
 */
public LumpyData(int dimension, double radius, double alpha) {
  this.radius = radius;
  this.centers = new MultiNormal(dimension);
  this.cluster = new ChineseRestaurant(alpha);
}