/** * Adjusts the Hessian's diagonal elements value and computes the next step * * @param lambda (Input) tuning * @param gradient (Input) gradient * @param step (Output) step * @return true if solver could compute the next step */ protected boolean computeStep( double lambda, DMatrixRMaj gradient , DMatrixRMaj step ) { final double mixture = config.mixture; for (int i = 0; i < diagOrig.numRows; i++) { double v = min(config.diagonal_max, max(config.diagonal_min,diagOrig.data[i])); diagStep.data[i] = v + lambda*(mixture + (1.0-mixture)*v); } hessian.setDiagonals( diagStep ); if( !hessian.initializeSolver()) { return false; } // In the book formulation it solves something like (B + lambda*I)*p = -g // but we don't want to modify g, so we apply the negative to the step instead if( hessian.solve(gradient,step) ) { CommonOps_DDRM.scale(-1, step); return true; } else { return false; } }
/** * Adjusts the Hessian's diagonal elements value and computes the next step * * @param lambda (Input) tuning * @param gradient (Input) gradient * @param step (Output) step * @return true if solver could compute the next step */ protected boolean computeStep( double lambda, DMatrixRMaj gradient , DMatrixRMaj step ) { final double mixture = config.mixture; for (int i = 0; i < diagOrig.numRows; i++) { double v = min(config.diagonal_max, max(config.diagonal_min,diagOrig.data[i])); diagStep.data[i] = v + lambda*(mixture + (1.0-mixture)*v); } hessian.setDiagonals( diagStep ); if( !hessian.initializeSolver()) { return false; } // In the book formulation it solves something like (B + lambda*I)*p = -g // but we don't want to modify g, so we apply the negative to the step instead if( hessian.solve(gradient,step) ) { CommonOps_DDRM.scale(-1, step); return true; } else { return false; } }
/**
 * Checks that values written with {@code setDiagonals} are read back by {@code extractDiagonals}.
 *
 * <p>The Hessian is set from a random 6x6 matrix, its diagonal is overwritten with a random
 * vector, and the extracted diagonal is compared to that vector element by element.</p>
 */
@Test
public void setDiagonals() {
	DMatrixRMaj M = RandomMatrices_DDRM.rectangle(6,6,rand);
	setHessian(alg,M);

	// the diagonal values that are expected to be read back
	DMatrixRMaj v = RandomMatrices_DDRM.rectangle(6,1,rand);
	alg.setDiagonals(v);

	// output storage starts with random content, so every element must be
	// overwritten by extractDiagonals for the comparison below to pass
	DMatrixRMaj found = RandomMatrices_DDRM.rectangle(6,1,rand);
	alg.extractDiagonals(found);

	for (int i = 0; i < M.numRows; i++) {
		// expected value first, actual second, so failure messages read correctly
		assertEquals(v.get(i), found.get(i), UtilEjml.TEST_F64);
	}
}