/**
 * Builds the Barnes-Hut t-SNE gradient for the current embedding {@code Y}.
 *
 * <p>On the first call this lazily allocates {@code yIncs} (zeros), {@code gains} (ones) and the
 * {@code SpTree} built from {@code Y}; later calls reuse them. Positive forces come from the
 * sparse {@code rows}/{@code cols}/{@code vals} edges, negative forces from the tree's non-edge
 * pass, and the negative part is normalized by the accumulated {@code sumQ} before it is
 * subtracted from the positive part.
 *
 * @return a gradient whose only entry is the result stored under {@code Y_GRAD}
 */
@Override
public Gradient gradient() {
    if (yIncs == null) {
        yIncs = zeros(Y.shape());
    }
    if (gains == null) {
        gains = ones(Y.shape());
    }

    AtomicDouble normalizer = new AtomicDouble(0);
    INDArray positive = Nd4j.create(Y.shape());
    INDArray negative = Nd4j.create(Y.shape());

    if (tree == null) {
        tree = new SpTree(Y);
    }
    tree.computeEdgeForces(rows, cols, vals, N, positive);
    for (int point = 0; point < N; point++) {
        INDArray negativeRow = negative.slice(point);
        tree.computeNonEdgeForces(point, theta, negativeRow, normalizer);
    }

    // Both operations are in-place: negative is scaled, then positive becomes the result.
    negative.divi(normalizer);
    INDArray dC = positive.subi(negative);

    Gradient result = new DefaultGradient();
    result.gradientForVariable().put(Y_GRAD, dC);
    return result;
}
// Forward the non-edge force computation for pointIndex to child node i, passing the same
// negativeForce buffer and sumQ accumulator through (presumably the recursive descent of the
// Barnes-Hut tree; the enclosing method is not visible here — confirm).
children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
/**
 * Builds the Barnes-Hut t-SNE gradient for the current embedding {@code Y}, running the force
 * computation inside the external loop workspace (or a dummy workspace when
 * {@code workspaceMode == WorkspaceMode.NONE}).
 *
 * <p>State that must outlive this call ({@code yIncs}, {@code gains} and the {@code SpTree}) is
 * allocated outside of any workspace. The returned gradient is detached before the workspace
 * scope closes. Arrays left attached to a workspace are not valid after the scope is exited:
 * their memory can be handed out again on the next scope entry, which would silently corrupt
 * the cached optimizer state and the returned gradient.
 *
 * @return a gradient whose only entry is the result stored under {@code Y_GRAD}
 */
@Override
public Gradient gradient() {
    // Lazily create the persistent buffers/tree once, outside of any workspace.
    try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
        if (yIncs == null)
            yIncs = zeros(Y.shape());
        if (gains == null)
            gains = ones(Y.shape());
        if (tree == null) {
            tree = new SpTree(Y);
            tree.setWorkspaceMode(workspaceMode);
        }
    }

    MemoryWorkspace workspace = workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                    workspaceConfigurationExternal, workspaceExternal);

    try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
        AtomicDouble sumQ = new AtomicDouble(0);
        /* Calculate gradient based on barnes hut approximation with positive and negative forces */
        INDArray posF = Nd4j.create(Y.shape());
        INDArray negF = Nd4j.create(Y.shape());
        tree.computeEdgeForces(rows, cols, vals, N, posF);
        for (int n = 0; n < N; n++) {
            tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ);
        }
        INDArray dC = posF.subi(negF.divi(sumQ));

        Gradient ret = new DefaultGradient();
        // detach() must happen before the scope closes so the caller receives memory that is
        // not reused by the workspace on its next entry.
        ret.gradientForVariable().put(Y_GRAD, dC.detach());
        return ret;
    }
}
@Override public double score() { // Get estimate of normalization term INDArray buff = Nd4j.create(numDimensions); AtomicDouble sum_Q = new AtomicDouble(0.0); for (int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, buff, sum_Q); // Loop over all edges to compute t-SNE error double C = .0; INDArray linear = Y; for (int n = 0; n < N; n++) { int begin = rows.getInt(n); int end = rows.getInt(n + 1); int ind1 = n; for (int i = begin; i < end; i++) { int ind2 = cols.getInt(i); buff.assign(linear.slice(ind1)); buff.subi(linear.slice(ind2)); double Q = pow(buff, 2).sum(Integer.MAX_VALUE).getDouble(0); Q = (1.0 / (1.0 + Q)) / sum_Q.doubleValue(); C += vals.getDouble(i) * FastMath.log(vals.getDouble(i) + Nd4j.EPS_THRESHOLD) / (Q + Nd4j.EPS_THRESHOLD); } } return C; }
// Estimate the normalization term sum_Q by running the tree's non-edge pass for every point
// 0..N-1 into the shared accumulator; buff receives the non-edge force output (presumably only
// sum_Q is consumed afterwards — the surrounding method is not visible here, confirm).
AtomicDouble sum_Q = new AtomicDouble(0.0); for (int n = 0; n < N; n++) tree.computeNonEdgeForces(n, theta, buff, sum_Q);