/**
 * Creates a tree node by handing every argument straight to
 * {@code init(parent, data, corner, width, indices, similarityFunction)}.
 *
 * @param parent             the parent node passed on to {@code init}
 * @param data               the data matrix passed on to {@code init}
 * @param corner             the cell corner passed on to {@code init}
 * @param width              the cell width passed on to {@code init}
 * @param indices            the shared set of points passed on to {@code init}
 * @param similarityFunction the similarity function name passed on to {@code init}
 */
public SpTree(
        SpTree parent,
        INDArray data,
        INDArray corner,
        INDArray width,
        Set<INDArray> indices,
        String similarityFunction) {
    init(parent, data, corner, width, indices, similarityFunction);
}
/** Orders two heap objects by DESCENDING distance: the arguments are swapped in {@code Double.compare}, so the result is negative when {@code o1} has the larger distance. */
@Override public int compare(HeapObject o1, HeapObject o2) { return Double.compare(o2.getDistance(), o1.getDistance()); } }
/**
 * Computes the gradient of the embedding {@code Y} from edge (positive) and non-edge
 * (negative) forces supplied by the {@link SpTree}.
 *
 * <p>The whole computation runs inside a workspace scope: a {@code DummyWorkspace} when
 * {@code workspaceMode} is {@code NONE}, otherwise the external workspace of the current
 * thread. {@code yIncs}, {@code gains} and {@code tree} are created on first use when they
 * are still {@code null}.
 *
 * @return a {@link DefaultGradient} whose {@code Y_GRAD} entry is the result of
 *         {@code posF.subi(negF.divi(sumQ))}
 */
@Override
public Gradient gradient() {
    MemoryWorkspace workspace;
    if (workspaceMode == WorkspaceMode.NONE) {
        workspace = new DummyWorkspace();
    } else {
        workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                        workspaceExternal);
    }

    try (MemoryWorkspace scope = workspace.notifyScopeEntered()) {
        // Lazily allocate the optimiser state, shaped like the embedding.
        if (yIncs == null) {
            yIncs = zeros(Y.shape());
        }
        if (gains == null) {
            gains = ones(Y.shape());
        }

        AtomicDouble sumQ = new AtomicDouble(0);

        // Barnes-Hut approximation: accumulate attractive and repulsive forces separately.
        INDArray posF = Nd4j.create(Y.shape());
        INDArray negF = Nd4j.create(Y.shape());

        if (tree == null) {
            tree = new SpTree(Y);
            tree.setWorkspaceMode(workspaceMode);
        }

        tree.computeEdgeForces(rows, cols, vals, N, posF);
        for (int point = 0; point < N; point++) {
            tree.computeNonEdgeForces(point, theta, negF.slice(point), sumQ);
        }

        // Both operations are the in-place variants (divi / subi).
        INDArray normalisedNegF = negF.divi(sumQ);
        INDArray dC = posF.subi(normalisedNegF);

        Gradient result = new DefaultGradient();
        result.gradientForVariable().put(Y_GRAD, dC);
        return result;
    }
}
/**
 * Computes the gradient of the embedding {@code Y} from edge (positive) and non-edge
 * (negative) forces supplied by the {@link SpTree}.
 *
 * <p>{@code yIncs}, {@code gains} and {@code tree} are created on first use when they are
 * still {@code null}.
 *
 * @return a {@link DefaultGradient} whose {@code Y_GRAD} entry is the result of
 *         {@code posF.subi(negF.divi(sumQ))}
 */
@Override
public Gradient gradient() {
    // Lazily allocate the optimiser state, shaped like the embedding.
    if (yIncs == null) {
        yIncs = zeros(Y.shape());
    }
    if (gains == null) {
        gains = ones(Y.shape());
    }

    AtomicDouble sumQ = new AtomicDouble(0);

    // Barnes-Hut approximation: accumulate attractive and repulsive forces separately.
    INDArray posF = Nd4j.create(Y.shape());
    INDArray negF = Nd4j.create(Y.shape());

    if (tree == null) {
        tree = new SpTree(Y);
    }

    tree.computeEdgeForces(rows, cols, vals, N, posF);
    for (int point = 0; point < N; point++) {
        tree.computeNonEdgeForces(point, theta, negF.slice(point), sumQ);
    }

    // Both operations are the in-place variants (divi / subi).
    INDArray normalisedNegF = negF.divi(sumQ);
    INDArray dC = posF.subi(normalisedNegF);

    Gradient result = new DefaultGradient();
    result.gradientForVariable().put(Y_GRAD, dC);
    return result;
}
int div = 1; for (int d = 0; d < D; d++) { newWidth.putScalar(d, .5 * boundary.width(d)); if ((i / div) % 2 == 1) newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d)); else newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d)); div *= 2; children[i] = new SpTree(this, data, newCorner, newWidth, indices); for (int j = 0; j < this.numChildren; j++) if (!success) success = children[j].insert(index[i]);
if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) return; double maxWidth = boundary.width().max(Integer.MAX_VALUE).getDouble(0); if (isLeaf() || maxWidth / FastMath.sqrt(D) < theta) { children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
private boolean insert(int index) { INDArray point = data.slice(index); if (!boundary.contains(point)) return false; cumSize++; double mult1 = (double) (cumSize - 1) / (double) cumSize; double mult2 = 1.0 / (double) cumSize; centerOfMass.muli(mult1); centerOfMass.addi(point.mul(mult2)); // If there is space in this quad tree and it is a leaf, add the object here if (isLeaf() && size < nodeCapacity) { this.index[size] = index; indices.add(point); size++; return true; } for (int i = 0; i < size; i++) { INDArray compPoint = data.slice(this.index[i]); if (compPoint.equals(point)) return true; } if (isLeaf()) subDivide(); // Find out where the point can be inserted for (int i = 0; i < numChildren; i++) { if (children[i].insert(index)) return true; } throw new IllegalStateException("Shouldn't reach this state"); }
@Override public double score() { // Get estimate of normalization term INDArray buff = Nd4j.create(numDimensions); AtomicDouble sum_Q = new AtomicDouble(0.0); for (int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, buff, sum_Q); // Loop over all edges to compute t-SNE error double C = .0; INDArray linear = Y; for (int n = 0; n < N; n++) { int begin = rows.getInt(n); int end = rows.getInt(n + 1); int ind1 = n; for (int i = begin; i < end; i++) { int ind2 = cols.getInt(i); buff.assign(linear.slice(ind1)); buff.subi(linear.slice(ind2)); double Q = pow(buff, 2).sum(Integer.MAX_VALUE).getDouble(0); Q = (1.0 / (1.0 + Q)) / sum_Q.doubleValue(); C += vals.getDouble(i) * FastMath.log(vals.getDouble(i) + Nd4j.EPS_THRESHOLD) / (Q + Nd4j.EPS_THRESHOLD); } } return C; }
/**
 * Searches the tree for the neighbours of {@code target}.
 *
 * <p>The recursive search is run for {@code k + 1} candidates and one entry is polled off
 * the queue when more than {@code k} were collected. The remaining entries are appended in
 * the order the priority queue yields them, and both output lists are reversed when
 * {@code invert} is set.
 *
 * @param target    the query; when {@code items} is non-null it must be a single row vector
 *                  with as many columns as {@code items}
 * @param k         number of neighbours requested; capped at the number of rows in {@code items}
 * @param results   output list, cleared and then filled with the points found
 * @param distances output list, cleared and then filled with the distance of each entry of
 *                  {@code results}
 * @throws ND4JIllegalStateException if {@code target} does not have the expected shape
 */
public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) {
    if (items != null
                    && (!target.isVector() || target.columns() != items.columns() || target.rows() > 1)) {
        throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", "
                        + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead");
    }

    k = Math.min(k, items.rows());
    results.clear();
    distances.clear();

    PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator());

    // Collect one candidate more than requested, then drop the head of the queue.
    search(root, target, k + 1, pq, Double.MAX_VALUE);
    if (pq.size() > k) {
        pq.poll();
    }

    // Drain the queue head-first into the two output lists.
    for (HeapObject head = pq.poll(); head != null; head = pq.poll()) {
        results.add(new DataPoint(head.getIndex(), head.getPoint()));
        distances.add(head.getDistance());
    }

    if (invert) {
        Collections.reverse(results);
        Collections.reverse(distances);
    }
}
/**
 * Checks this node and, recursively, all of its children: every point stored in a node
 * must lie inside that node's boundary.
 *
 * @return {@code true} if no stored point falls outside the boundary of its node,
 *         {@code false} as soon as one violation is found
 */
public boolean isCorrect() {
    // Points held directly by this node.
    for (int slot = 0; slot < size; slot++) {
        INDArray stored = data.slice(index[slot]);
        if (!boundary.contains(stored)) {
            return false;
        }
    }

    if (isLeaf()) {
        return true;
    }

    // Stop at the first child that fails; later children are not visited.
    for (int child = 0; child < numChildren; child++) {
        if (!children[child].isCorrect()) {
            return false;
        }
    }
    return true;
}
/**
 * Initialises this node's state: dimensions, per-node capacity, child slots, boundary cell
 * and centre-of-mass buffers.
 *
 * <p>{@code D} and {@code N} are taken from the columns and rows of {@code data}.
 * {@code numChildren} is doubled {@code D - 1} times starting from its current value.
 * The boundary stores copies ({@code dup()}) of {@code corner} and {@code width}.
 *
 * @param parent             the parent node ({@code null} is passed for the root)
 * @param data               the data matrix, one point per row
 * @param corner             corner of this node's cell (copied)
 * @param width              width of this node's cell (copied)
 * @param indices            the set that collects inserted points
 * @param similarityFunction the similarity function name to remember
 */
private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices,
                String similarityFunction) {
    this.parent = parent;
    D = data.columns();
    N = data.rows();
    this.similarityFunction = similarityFunction;

    // N % NODE_RATIO is 0 whenever N is a multiple of NODE_RATIO. A node with capacity 0
    // can never store a point, so every insert would subdivide again without end.
    // Keep at least one slot per node.
    nodeCapacity = Math.max(1, N % NODE_RATIO);
    index = new int[nodeCapacity];

    for (int d = 1; d < this.D; d++) {
        numChildren *= 2;
    }

    this.indices = indices;
    isLeaf = true;
    size = 0;
    cumSize = 0;
    children = new SpTree[numChildren];
    this.data = data;

    boundary = new Cell(D);
    boundary.setCorner(corner.dup());
    boundary.setWidth(width.dup());

    centerOfMass = Nd4j.create(D);
    buf = Nd4j.create(D);
}
pq.poll(); pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance)); if (pq.size() == k) tau = pq.peek().getDistance();
/**
 * Builds a root node covering all of {@code data} and inserts every row into it.
 *
 * <p>The cell is centred on the mean along dimension 0. For each dimension the width is the
 * larger of {@code max - mean} and {@code mean - min + Nd4j.EPS_THRESHOLD}.
 *
 * @param data               the data matrix, one point per row
 * @param indices            the set that collects inserted points
 * @param similarityFunction the similarity function name to use
 */
public SpTree(INDArray data, Set<INDArray> indices, String similarityFunction) {
    this.indices = indices;
    this.N = data.rows();
    this.D = data.columns();
    this.similarityFunction = similarityFunction;

    INDArray meanY = data.mean(0);
    INDArray minY = data.min(0);
    INDArray maxY = data.max(0);

    INDArray width = Nd4j.create(meanY.shape());
    for (int dim = 0; dim < width.length(); dim++) {
        double centre = meanY.getDouble(dim);
        double upperSpread = maxY.getDouble(dim) - centre;
        double lowerSpread = centre - minY.getDouble(dim) + Nd4j.EPS_THRESHOLD;
        width.putScalar(dim, FastMath.max(upperSpread, lowerSpread));
    }

    init(null, data, meanY, width, indices, similarityFunction);
    fill(N);
}
if (j >= results.size()) break; indices.putScalar(j, results.get(j).getIndex());
/**
 * Inserts rows {@code 0 .. n - 1} into the tree, but only when {@code indices} is still
 * empty and this node has no parent. Otherwise a warning is logged and nothing is inserted.
 *
 * @param n number of rows to insert
 */
private void fill(int n) {
    if (!indices.isEmpty() || parent != null) {
        log.warn("Called fill already");
        return;
    }

    for (int row = 0; row < n; row++) {
        // The trace line is emitted before the insert itself.
        log.trace("Inserted " + row);
        insert(row);
    }
}
/** * * @param items the items to use * @param similarityFunction the similarity function to use * @param workers number of parallel workers for tree building (increases memory requirements!) * @param invert whether to invert the metric (different optimization objective) */ public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) { if (this.items == null) { this.items = Nd4j.create(items.size(), items.get(0).getPoint().columns()); } this.workers = workers; for (int i = 0; i < items.size(); i++) { //itemsList.add(items.get(i).getPoint()); this.items.putRow(i, items.get(i).getPoint()); } this.invert = invert; this.similarityFunction = similarityFunction; root = buildFromPoints(this.items); }
public void search() { results = new ArrayList<>(); distances = new ArrayList<>(); //initial search //vpTree.search(target,k,results,distances); //fill till there is k results //by going down the list // if(results.size() < k) { INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1); vpTree.calcDistancesRelativeTo(target, distancesArr); INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert()); results.clear(); distances.clear(); if (vpTree.getItems().isVector()) { for (int i = 0; i < k; i++) { int idx = sortWithIndices[0].getInt(i); results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx)))); distances.add(sortWithIndices[1].getDouble(idx)); } } else { for (int i = 0; i < k; i++) { int idx = sortWithIndices[0].getInt(i); results.add(new DataPoint(idx, vpTree.getItems().getRow(idx))); distances.add(sortWithIndices[1].getDouble(idx)); } } }
AtomicDouble sum_Q = new AtomicDouble(0.0); for (int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, buff, sum_Q);