/**
 * Adds ten items with random weights, then verifies the sampling boundaries with
 * {@code checkSelfConsistent} and checks that every stored weight is reported back exactly.
 */
@Test
public void testInsert() {
  final int itemCount = 10;
  Random random = RandomUtils.getRandom();
  Multinomial<Integer> multinomial = new Multinomial<>();
  double[] expectedWeights = new double[itemCount];
  for (int item = 0; item < itemCount; item++) {
    double weight = random.nextDouble();
    expectedWeights[item] = weight;
    multinomial.add(item, weight);
  }

  checkSelfConsistent(multinomial);

  for (int item = 0; item < expectedWeights.length; item++) {
    // Weights must round-trip unchanged, hence a tolerance of zero.
    assertEquals(expectedWeights[item], multinomial.getWeight(item), 0);
  }
}
/**
 * Builds a multinomial from five distinct strings (count one each). It samples at, just above,
 * and just below each fifth of the unit interval, and requires every string to be hit between
 * two and four times. It also checks that sampling at 1 agrees with sampling just below 1.
 */
@Test
public void testEvenSplit() {
  Multiset<String> items = HashMultiset.create();
  for (int i = 0; i < 5; i++) {
    items.add(String.valueOf(i));
  }
  Multinomial<String> dist = new Multinomial<>(items);

  double epsilon = 1.0e-15;
  Multiset<String> hits = HashMultiset.create();
  for (int bucket = 0; bucket < 5; bucket++) {
    double lower = bucket * 0.2;
    double upper = (bucket + 1) * 0.2;
    hits.add(dist.sample(lower));
    hits.add(dist.sample(lower + epsilon));
    hits.add(dist.sample(upper - epsilon));
  }

  assertEquals(5, hits.elementSet().size());
  for (String value : hits.elementSet()) {
    // Three probes per bucket; a tolerance of ~1 absorbs boundary rounding.
    assertEquals(3, hits.count(value), 1.01);
  }
  assertTrue(hits.contains(dist.sample(1)));
  assertEquals(dist.sample(1 - epsilon), dist.sample(1));
}
/** A multinomial over a single item must return that item when sampled at 0, 0.1 and 1. */
@Test
public void testSingleton() {
  Multiset<String> single = HashMultiset.create();
  single.add("one");
  Multinomial<String> dist = new Multinomial<>(single);
  for (double point : new double[] {0, 0.1, 1}) {
    assertEquals("one", dist.sample(point));
  }
}
cnt.add(s0.sample(p1 - EPSILON)); assertEquals(s0.sample(p0), s1.sample(p0)); assertEquals(s0.sample(p0 + EPSILON), s1.sample(p0 + EPSILON)); assertEquals(s0.sample(p1 - EPSILON), s1.sample(p1 - EPSILON)); assertEquals(s0.sample(p0), s2.sample(p0)); assertEquals(s0.sample(p0 + EPSILON), s2.sample(p0 + EPSILON)); assertEquals(s0.sample(p1 - EPSILON), s2.sample(p1 - EPSILON)); assertEquals(s0.sample(0), s1.sample(0)); assertEquals(s0.sample(0 + EPSILON), s1.sample(0 + EPSILON)); assertEquals(s0.sample(1 - EPSILON), s1.sample(1 - EPSILON)); assertEquals(s0.sample(1), s1.sample(1)); assertEquals(s0.sample(0), s2.sample(0)); assertEquals(s0.sample(0 + EPSILON), s2.sample(0 + EPSILON)); assertEquals(s0.sample(1 - EPSILON), s2.sample(1 - EPSILON)); assertEquals(s0.sample(1), s2.sample(1)); assertEquals(5, cnt.elementSet().size()); assertTrue(Math.abs(ref.get(v) - cnt.count(v)) <= 2); assertTrue(cnt.contains(s0.sample(1))); assertEquals(s0.sample(1 - EPSILON), s0.sample(1));
/**
 * Checks that sampling agrees with the table's weights.
 *
 * <p>It walks the cumulative distribution built from {@code table.getWeights()} and probes just
 * inside both ends of every positive-weight interval. Each positive-weight item must be hit
 * exactly twice and each zero-weight item never. It also asserts that the normalized weights sum
 * to 1. The sampled {@code Integer} is used directly as an array index, so items are expected to
 * be {@code 0 .. weights.size() - 1}.
 */
private static void checkSelfConsistent(Multinomial<Integer> table) {
  final double delta = 1.0e-9;
  List<Double> weights = table.getWeights();
  double totalWeight = table.getWeight();
  int[] hits = new int[weights.size()];
  double cumulative = 0;
  for (double weight : weights) {
    if (weight > 0) {
      // Just below this interval's start is expected to hit the previous positive-weight item.
      if (cumulative > 0) {
        hits[table.sample(cumulative - delta)]++;
      }
      // Just above the start is expected to hit the current item.
      hits[table.sample(cumulative + delta)]++;
    }
    cumulative += weight / totalWeight;
  }
  // Just below the end of [0, 1] closes the last positive-weight interval.
  hits[table.sample(cumulative - delta)]++;
  assertEquals(1, cumulative, delta);

  for (int i = 0; i < weights.size(); i++) {
    int expectedHits = table.getWeight(i) > 0 ? 2 : 0;
    assertEquals(expectedHits, hits[i]);
  }
}
}
Random rand = RandomUtils.getRandom(); Multinomial<Integer> table = new Multinomial<>(); assertEquals(0, table.getWeight(), 1.0e-9); table.add(i, p[i]); total += p[i]; assertEquals(total, table.getWeight(), 1.0e-9); assertEquals(total, table.getWeight(), 1.0e-9); checkSelfConsistent(table); checkSelfConsistent(table); assertEquals(total, table.getWeight(), 1.0e-9); for (int i = 0; i < 10; i++) { assertEquals(p[i], table.getWeight(i), 0); assertEquals(p[i] / total, table.getProbability(i), 1.0e-10); assertEquals(total , table.getWeight(), 1.0e-9); for (int i = 0; i < 10; i++) { assertEquals(p[i], table.getWeight(i), 0); assertEquals(p[i] / total, table.getProbability(i), 1.0e-10); checkSelfConsistent(table); assertEquals(p[i], table.getWeight(i), 0);