/**
 * Creates an instance via the no-arg constructor and then registers every element of {@code things}
 * by calling {@code add(value, weight)} with that element's value and weight.
 *
 * @param things the weighted items to add, in iteration order
 */
public Multinomial(Iterable<WeightedThing<T>> things) {
  this();
  for (WeightedThing<T> item : things) {
    // Value first, then weight: same argument order as add(...).
    add(item.getValue(), item.getWeight());
  }
}
/**
 * Yields the value of the next element of {@code projected}, or signals end of iteration via
 * {@code endOfData()} once {@code projected} has no more elements.
 */
@Override
protected Vector computeNext() {
  if (projected.hasNext()) {
    // Unwrap the weighted element and hand back only its vector.
    return projected.next().getValue();
  }
  return endOfData();
}
};
/**
 * Yields the value of the next element of {@code projected}, or signals end of iteration via
 * {@code endOfData()} once {@code projected} has no more elements.
 */
@Override
protected Vector computeNext() {
  if (projected.hasNext()) {
    // Unwrap the weighted element and hand back only its vector.
    return projected.next().getValue();
  }
  return endOfData();
}
};
/**
 * Yields the value of the next element of {@code projected}, or signals end of iteration via
 * {@code endOfData()} once {@code projected} has no more elements.
 */
@Override
protected Vector computeNext() {
  if (projected.hasNext()) {
    // Unwrap the weighted element and hand back only its vector.
    return projected.next().getValue();
  }
  return endOfData();
}
};
@Override public Vector apply(WeightedThing<Vector> input) { Preconditions.checkArgument(input != null, "input is null"); //noinspection ConstantConditions return input.getValue(); } }
/**
 * Creates an instance via the no-arg constructor and then registers every element of {@code things}
 * by calling {@code add(value, weight)} with that element's value and weight.
 *
 * @param things the weighted items to add, in iteration order
 */
public Multinomial(Iterable<WeightedThing<T>> things) {
  this();
  for (WeightedThing<T> item : things) {
    // Value first, then weight: same argument order as add(...).
    add(item.getValue(), item.getWeight());
  }
}
/**
 * Returns at most {@code limit} entries produced by {@code searchInternal(query)}, with each stored
 * {@code HashedVector} unwrapped back to the vector it wraps (weights are kept unchanged).
 * <p>
 * The internal queue is drained completely and the drained sequence is then reversed, so results come
 * back in the opposite of the queue's pop order (presumably best match first; confirm against the
 * queue's ordering).
 *
 * @param query the vector to search for
 * @param limit the maximum number of results to return
 * @return the unwrapped results; when truncated, this is a {@code subList} view of the full list
 */
@Override
public List<WeightedThing<Vector>> search(Vector query, int limit) {
  PriorityQueue<WeightedThing<Vector>> candidates = searchInternal(query);
  List<WeightedThing<Vector>> ordered = Lists.newArrayListWithExpectedSize(candidates.size());
  while (candidates.size() > 0) {
    WeightedThing<Vector> hashed = candidates.pop();
    Vector unwrapped = ((HashedVector) hashed.getValue()).getVector();
    ordered.add(new WeightedThing<Vector>(unwrapped, hashed.getWeight()));
  }
  // Flip the pop order before applying the limit.
  Collections.reverse(ordered);
  return limit < ordered.size() ? ordered.subList(0, limit) : ordered;
}
/**
 * Returns at most {@code limit} entries produced by {@code searchInternal(query)}, with each stored
 * {@code HashedVector} unwrapped back to the vector it wraps (weights are kept unchanged).
 * <p>
 * The internal queue is drained completely and the drained sequence is then reversed, so results come
 * back in the opposite of the queue's pop order (presumably best match first; confirm against the
 * queue's ordering).
 *
 * @param query the vector to search for
 * @param limit the maximum number of results to return
 * @return the unwrapped results; when truncated, this is a {@code subList} view of the full list
 */
@Override
public List<WeightedThing<Vector>> search(Vector query, int limit) {
  PriorityQueue<WeightedThing<Vector>> candidates = searchInternal(query);
  List<WeightedThing<Vector>> ordered = Lists.newArrayListWithExpectedSize(candidates.size());
  while (candidates.size() > 0) {
    WeightedThing<Vector> hashed = candidates.pop();
    Vector unwrapped = ((HashedVector) hashed.getValue()).getVector();
    ordered.add(new WeightedThing<Vector>(unwrapped, hashed.getWeight()));
  }
  // Flip the pop order before applying the limit.
  Collections.reverse(ordered);
  return limit < ordered.size() ? ordered.subList(0, limit) : ordered;
}
/**
 * Returns at most {@code limit} entries produced by {@code searchInternal(query)}, with each stored
 * {@code HashedVector} unwrapped back to the vector it wraps (weights are kept unchanged).
 * <p>
 * The internal queue is drained completely and the drained sequence is then reversed, so results come
 * back in the opposite of the queue's pop order (presumably best match first; confirm against the
 * queue's ordering).
 *
 * @param query the vector to search for
 * @param limit the maximum number of results to return
 * @return the unwrapped results; when truncated, this is a {@code subList} view of the full list
 */
@Override
public List<WeightedThing<Vector>> search(Vector query, int limit) {
  PriorityQueue<WeightedThing<Vector>> candidates = searchInternal(query);
  List<WeightedThing<Vector>> ordered = Lists.newArrayListWithExpectedSize(candidates.size());
  while (candidates.size() > 0) {
    WeightedThing<Vector> hashed = candidates.pop();
    Vector unwrapped = ((HashedVector) hashed.getValue()).getVector();
    ordered.add(new WeightedThing<>(unwrapped, hashed.getWeight()));
  }
  // Flip the pop order before applying the limit.
  Collections.reverse(ordered);
  return limit < ordered.size() ? ordered.subList(0, limit) : ordered;
}
/**
 * Unwraps a weighted {@code HashedVector} into a new weighted item holding the wrapped vector.
 *
 * @param input a weighted item whose value must be a {@code HashedVector} (otherwise the cast fails)
 * @return a new item with the inner vector and the same weight as {@code input}
 */
protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
  HashedVector hashed = (HashedVector) input.getValue();
  return new WeightedThing<Vector>(hashed.getVector(), input.getWeight());
}
/**
 * Unwraps a weighted {@code HashedVector} into a new weighted item holding the wrapped vector.
 *
 * @param input a weighted item whose value must be a {@code HashedVector} (otherwise the cast fails)
 * @return a new item with the inner vector and the same weight as {@code input}
 */
protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
  HashedVector hashed = (HashedVector) input.getValue();
  return new WeightedThing<>(hashed.getVector(), input.getWeight());
}
/**
 * Unwraps a weighted {@code HashedVector} into a new weighted item holding the wrapped vector.
 *
 * @param input a weighted item whose value must be a {@code HashedVector} (otherwise the cast fails)
 * @return a new item with the inner vector and the same weight as {@code input}
 */
protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
  HashedVector hashed = (HashedVector) input.getValue();
  return new WeightedThing<Vector>(hashed.getVector(), input.getWeight());
}
/**
 * Computes the summaries for the distances in each cluster.
 * <p>
 * Each datapoint is matched to the single result returned by a {@code ProjectionSearch} over the centroids.
 * The distance between the datapoint and that centroid is then added to the summarizer at the centroid's index.
 *
 * @param datapoints iterable of datapoints.
 * @param centroids iterable of Centroids. Search results are cast to {@code Centroid}, and their
 *     {@code getIndex()} is used as a list position, so indices are expected to lie in
 *     {@code [0, searcher.size())}.
 * @param distanceMeasure distance used both by the internal {@code ProjectionSearch} and for the
 *     summarized distances.
 * @return a list of OnlineSummarizers where the i-th element is the summarizer corresponding to the cluster whose
 *     index is i; an empty list when there are no centroids.
 */
public static List<OnlineSummarizer> summarizeClusterDistances(Iterable<? extends Vector> datapoints,
    Iterable<? extends Vector> centroids, DistanceMeasure distanceMeasure) {
  UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
  searcher.addAll(centroids);
  int numClusters = searcher.size();
  List<OnlineSummarizer> summarizers = new ArrayList<>();
  // With no centroids, searching below would have nothing to return, so bail out early.
  if (numClusters == 0) {
    return summarizers;
  }
  while (summarizers.size() < numClusters) {
    summarizers.add(new OnlineSummarizer());
  }
  for (Vector point : datapoints) {
    Centroid nearest = (Centroid) searcher.search(point, 1).get(0).getValue();
    summarizers.get(nearest.getIndex()).add(distanceMeasure.distance(point, nearest));
  }
  return summarizers;
}
/**
 * Measures result overlap between two searchers on the first 100 rows of {@code testData}.
 * <p>
 * For each row, it counts how many row indices occur both among the (up to) 150 results of
 * {@code cut.search} and among the (up to) 100 results of {@code ref.search}. Each count is added to the
 * returned summarizer. Result values are cast to {@code WeightedVector} to read their index.
 *
 * @param testData query rows; at least 100 rows are read
 * @param ref the reference searcher
 * @param cut the searcher being evaluated
 * @return a summarizer over the per-query overlap counts
 */
private static OnlineSummarizer evaluateStrategy(Matrix testData, BruteSearch ref,
    LocalitySensitiveHashSearch cut) {
  OnlineSummarizer overlapStats = new OnlineSummarizer();
  for (int row = 0; row < 100; row++) {
    Vector query = testData.viewRow(row);

    BitSet candidateIndices = new BitSet();
    for (WeightedThing<Vector> hit : cut.search(query, 150)) {
      candidateIndices.set(((WeightedVector) hit.getValue()).getIndex());
    }

    BitSet referenceIndices = new BitSet();
    for (WeightedThing<Vector> hit : ref.search(query, 100)) {
      referenceIndices.set(((WeightedVector) hit.getValue()).getIndex());
    }

    // Intersection size = how many reference hits the evaluated searcher also found.
    candidateIndices.and(referenceIndices);
    overlapStats.add(candidateIndices.cardinality());
  }
  return overlapStats;
}
/**
 * Checks that {@code searcher}'s first result matches {@code referenceSearchFirst} for every query row.
 * It also checks that its average per-query time is lower than the brute-force average.
 */
@Test
public void testOverlapAndRuntimeSearchFirst() {
  searcher.clear();
  searcher.addAll(dataPoints);
  Pair<List<WeightedThing<Vector>>, Long> results = getResultsAndRuntimeSearchFirst(searcher, queries);

  int numQueries = queries.numRows();
  int numFirstMatches = 0;
  for (int i = 0; i < numQueries; ++i) {
    Vector expected = referenceSearchFirst.getFirst().get(i).getValue();
    Vector actual = results.getFirst().get(i).getValue();
    if (expected.equals(actual)) {
      ++numFirstMatches;
    }
  }

  // NOTE(review): matches are taken from referenceSearchFirst, but the brute-force time comes from
  // `reference`. Confirm that this is the intended baseline for a searchFirst timing comparison.
  double bruteSearchAvgTime = reference.getSecond() / (double) numQueries;
  double searcherAvgTime = results.getSecond() / (double) numQueries;
  System.out.printf("%s: first matches %d [%d]; avg_time(1 query) %f(s) [%f]\n",
      searcher.getClass().getName(), numFirstMatches, numQueries, searcherAvgTime, bruteSearchAvgTime);

  assertEquals("Closest vector returned doesn't match", numQueries, numFirstMatches);
  assertTrue("Searcher " + searcher.getClass().getName() + " slower than brute",
      bruteSearchAvgTime > searcherAvgTime);
}
// The annotation below belongs to whatever declaration follows this method in the file.
@Test
/**
 * For every stored point, checks that {@code searchFirst(point, false)} returns the point itself at weight 0
 * and equals {@code search(point, 2).get(0)}, and that {@code searchFirst(point, true)} equals
 * {@code search(point, 2).get(1)}.
 */
@Test
public void testSearchFirst() {
  searcher.clear();
  searcher.addAll(dataPoints);
  for (Vector point : dataPoints) {
    WeightedThing<Vector> best = searcher.searchFirst(point, false);
    WeightedThing<Vector> nextBest = searcher.searchFirst(point, true);
    List<WeightedThing<Vector>> topTwo = searcher.search(point, 2);

    assertEquals("First isn't self", 0, best.getWeight(), 0);
    assertEquals("First isn't self", point, best.getValue());
    assertEquals("First doesn't match", best, topTwo.get(0));
    assertEquals("Second doesn't match", nextBest, topTwo.get(1));
  }
}
/**
 * For up to 100 stored points, perturbs each point with small Gaussian noise and searches for its two
 * nearest neighbours.
 * <p>
 * Checks that the best hit lies about {@code |epsilon|} away (within 0.1) and that the second hit is
 * strictly further than the first.
 */
@Test
public void testNearMatch() {
  searcher.clear();
  List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(dataPoints, 100));
  searcher.addAllMatrixSlicesAsWeightedVectors(dataPoints);
  MultiNormal noise = new MultiNormal(0.01, new DenseVector(20));
  for (MatrixSlice slice : queries) {
    Vector original = slice.vector();
    Vector epsilon = noise.sample();
    // Perturb BEFORE searching. Searching the exact stored point would return weight ~0, and the
    // distance assertions would then pass only because |epsilon| happens to be below the tolerance.
    Vector query = original.plus(epsilon);
    List<WeightedThing<Vector>> r = searcher.search(query, 2);

    assertEquals("Distance has to be small", epsilon.norm(2), r.get(0).getWeight(), 1.0e-1);
    assertEquals("Answer must be substantially the same as query", epsilon.norm(2),
        r.get(0).getValue().minus(query).norm(2), 1.0e-1);
    assertTrue("Wrong answer must be further away", r.get(1).getWeight() > r.get(0).getWeight());
  }
}