@Override
public double getLossWeightAlphaSum() {
    return getAlphaSum();
}
/**
 * Trains LCLR with the squared hinge loss.
 * <p>
 * Always use
 * {@link WeightVector#predictLCLRBinaryScore(IInstance, AbstractStructureFinder)}
 * to get the prediction score for binary examples.
 *
 * @param init_wv
 *            The initial weight vector. Since this learning problem is not
 *            convex, a good initialization point is important.
 * @param struct_finder
 *            The inference solver (dynamic programming, ILP, ...). Given an
 *            input (IInstance) and a weight vector (WeightVector), it returns
 *            the best structure (AbstractStructures).
 * @param bp
 *            Binary labeled dataset.
 * @param para
 *            Parameters for JLIS.
 * @return the learned weight vector
 * @throws Exception
 */
@Override
public WeightVector trainLCLR(final WeightVector init_wv,
        final AbstractStructureFinder struct_finder, final BinaryProblem bp,
        final JLISParameters para) throws Exception {
    return getJointWeightVectorFast(init_wv, struct_finder, empty_s, bp, para);
}
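/*
 * Minimal usage sketch for trainLCLR above. The 'learner' instance, the
 * MyBinaryStructureFinder solver, numFeatures, testInstance, and the
 * already-built BinaryProblem bp are illustrative assumptions, not names
 * taken from this codebase.
 */
JLISParameters para = new JLISParameters();
para.c_binary = 1.0;                      // binary penalty parameter (assumed value)
para.total_number_features = numFeatures; // assumed known from feature extraction
WeightVector init_wv = new WeightVector(para.total_number_features + 1);

AbstractStructureFinder finder = new MyBinaryStructureFinder(); // hypothetical solver
WeightVector wv = learner.trainLCLR(init_wv, finder, bp, para); // 'learner' declares trainLCLR

// At test time, score binary examples through the weight vector, as the Javadoc advises:
double score = wv.predictLCLRBinaryScore(testInstance, finder);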
public static RerankerModel trainRerankerModel(double C, int n_thread,
        StructuredProblem train) throws Exception {
    RerankerModel model = new RerankerModel();
    model.para = new JLISParameters();
    // para.total_number_features = train.label_mapping.size() * train.n_base_feature_in_train;
    model.para.c_struct = C;
    model.para.TRAINMINI = true;
    // play with the following two parameters if you want to solve SSVM more tightly
    model.para.DUAL_GAP = 0.01;
    model.para.WORKINGSETSVM_STOP = 0.01;

    System.out.println("Initializing Solvers...");
    System.out.flush();
    AbstractLossSensitiveStructureFinder[] s_finder_list = new AbstractLossSensitiveStructureFinder[n_thread];
    for (int i = 0; i < s_finder_list.length; i++) {
        s_finder_list[i] = new RerankerBestItemFinder();
    }
    System.out.println("Done!");
    System.out.flush();

    L2LossParallelJLISLearner learner = new L2LossParallelJLISLearner();

    // train model
    model.wv = learner.parallelTrainStructuredSVM(s_finder_list, train, model.para);
    return model;
}
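/*
 * Hypothetical call to the factory above; rerankTrain is an assumed,
 * already-populated StructuredProblem of reranking candidates.
 */
RerankerModel rerankerModel = trainRerankerModel(0.1, 4, rerankTrain); // C = 0.1, 4 solver threads
// rerankerModel.wv holds the learned weights; rerankerModel.para records the solver settings.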
WeightVector cur_wv = getWeightVectorBySumAlpahFv(alpha_ins_list,
        is_extendable, n_ex, verbose_level);
L2SolverInfo si = new L2SolverInfo();
alpha_ins_list[idx].solveSubProblemAndUpdateW(si, cur_wv);
double obj = getDualObjectiveWithCurrentCuts(alpha_ins_list, cur_wv);
return new WorkingSetSVMResult(cur_wv, obj, finished);
protected void updateStructuresForBinaryPositiveExamples(WeightVector new_wv,
        L2LossInstanceWithAlphas[] alpha_ins_list,
        AbstractStructureFinder struct_finder) throws Exception {
    // update positive examples
    int n_p = 0;
    int n_p_changed = 0;
    int total_size = alpha_ins_list.length;
    for (int i = 0; i < total_size; i++) {
        L2LossInstanceWithAlphas p_ins = alpha_ins_list[i];
        if (i % verbose_step == 0)
            System.out.println("positive_xi inference stage: " + i + "/" + total_size);
        if (p_ins.isBinary() && p_ins.getY() == 1) {
            n_p += 1;
            double updated = p_ins.updateRepresentationCollection(new_wv, struct_finder);
            if (updated > 0)
                n_p_changed += 1;
        }
    }
    System.out.println("Among " + n_p + " examples, " + n_p_changed + " updated");
}
WeightVector init_wv = parallelTrainStructuredSVM(struct_finder_list, sp, para);
WeightVector res_wv = multiThreadGetJointWeightVector(init_wv,
        struct_finder_list, sp, bp, para);
public static double getDualObjectiveWithCurrentCuts(
        L2LossInstanceWithAlphas[] alpha_ins_list, WeightVector cur_wv) {
    double obj = 0;
    obj += cur_wv.getTwoNormSquare() * 0.5;
    for (int i = 0; i < alpha_ins_list.length; i++) {
        L2LossInstanceWithAlphas instanceWithAlphas = alpha_ins_list[i];
        double w_sum = instanceWithAlphas.getLossWeightAlphaSum();
        double sum = instanceWithAlphas.getAlphaSum();
        double C = instanceWithAlphas.getC();
        obj -= w_sum;
        obj += (1.0 / (4.0 * C)) * sum * sum;
    }
    return obj;
}
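/*
 * For reference, a sketch of the quantity assembled above, assuming that for
 * each instance i getLossWeightAlphaSum() is the loss-weighted sum of its dual
 * variables and getAlphaSum() is their plain sum:
 *
 *   obj(\alpha) = \frac{1}{2}\,\lVert w(\alpha)\rVert^2
 *               - \sum_i \sum_k \ell_{i,k}\,\alpha_{i,k}
 *               + \sum_i \frac{1}{4 C_i}\Big(\sum_k \alpha_{i,k}\Big)^2
 *
 * where k ranges over the cutting planes ("cuts") currently cached for example i.
 */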
subset.add(alpha_ins_list[j]);
inf_runner_list[i] = new PositiveInferenceHandler(struct_finder_list[i],
        subset, new_wv, verbose_level);
inf_runner_list[i].start();
inf_runner_list[i].join();
// unweighted form: each structured example uses C_structure directly
alpha_ins_list[i] = new L2LossStructureInstanceWithAlphas(
        sp.input_list.get(i), sp.output_list.get(i), C_structure);
// weighted form: C_structure scaled by the per-example weight
alpha_ins_list[i] = new L2LossStructureInstanceWithAlphas(
        sp.input_list.get(i), sp.output_list.get(i),
        C_structure * sp.weight_list.get(i));

// unweighted form for the binary examples
for (int i = 0; i < bp.size(); i++)
    if (bp.output_list.get(i) == 1)
        alpha_ins_list[i + struct_size] = new L2LossPositiveInstanceWithAlphas(
                bp.output_list.get(i), bp.input_list.get(i), C_binary);
    else
        alpha_ins_list[i + struct_size] = new L2LossNegativeInstanceWithAlphas(
                bp.output_list.get(i), bp.input_list.get(i), C_binary);

// weighted form for the binary examples: C_binary scaled by the per-example weight
for (int i = 0; i < bp.size(); i++) {
    if (bp.output_list.get(i) == 1)
        alpha_ins_list[i + struct_size] = new L2LossPositiveInstanceWithAlphas(
                bp.output_list.get(i), bp.input_list.get(i),
                C_binary * bp.weight_list.get(i));
    else
        alpha_ins_list[i + struct_size] = new L2LossNegativeInstanceWithAlphas(
                bp.output_list.get(i), bp.input_list.get(i),
                C_binary * bp.weight_list.get(i));
}
subset.add(alpha_ins_list[j]);
inf_runner_list[i] = new NegAndStructInferenceHandler(struct_finder_list[i],
        subset, new_wv, verbose_level);
inf_runner_list[i].start();
inf_runner_list[i].join();
public static WeightVector getWeightVectorBySumAlpahFv(
        L2LossInstanceWithAlphas[] alpha_ins_list, boolean is_extendable,
        int n_ex, int verbose_level) {
    int max_n = -1;
    for (int i = 0; i < n_ex; i++) {
        int cur_idx = alpha_ins_list[i].getMaxIdx();
        if (cur_idx > max_n)
            max_n = cur_idx;
    }
    if (verbose_level >= JLISParameters.VLEVEL_MID)
        System.out.println("number of features: " + max_n);

    WeightVector cur_wv = new WeightVector(max_n + 1);
    cur_wv.setExtendable(is_extendable);
    // double[] cur_w = new double[max_n + 1];
    for (int i = 0; i < n_ex; i++) {
        alpha_ins_list[i].fillWeightVector(cur_wv);
    }
    return cur_wv;
}
protected WeightVector multiThreadGetJointWeightVector(WeightVector old_wv,
        final AbstractStructureFinder[] struct_finder_list,
        StructuredProblem sp, BinaryProblem bp, JLISParameters para)
        throws Exception {
    int struct_size = sp.size();
    int binary_size = bp.size();
    int total_size = struct_size + binary_size;

    System.out.println("Number of training data: #struct: " + struct_size
            + " #binary: " + binary_size);

    // allocate bias term for indirect supervision
    WeightVector new_wv = new WeightVector(old_wv, 0);

    L2LossInstanceWithAlphas[] alpha_ins_list = initArrayOfInstances(sp, bp,
            para.c_struct, para.c_binary, struct_size, total_size);

    return multitreadTrainJLIS(struct_finder_list, para.MAX_OUT_ITER,
            struct_size, total_size, new_wv, alpha_ins_list, para).getFirst();
}
        final BinaryProblem bp, final JLISParameters para) throws Exception {
    return multiThreadGetJointWeightVector(init_wv, struct_finder_list,
            empty_s, bp, para);
public static MulticlassModel trainMultiClassModel(double C, int n_thread,
        LabeledMulticlassData train) throws Exception {
    MulticlassModel model = new MulticlassModel();
    model.lab_mapping = train.label_mapping;
    // for the bias term
    model.n_base_feature_in_train = train.n_base_feature_in_train;

    model.para = new JLISParameters();
    // para.total_number_features = train.label_mapping.size() * train.n_base_feature_in_train;
    model.para.c_struct = C;
    model.para.TRAINMINI = true;
    // play with the following two parameters if you want to solve SSVM more tightly
    model.para.DUAL_GAP = 0.01;
    model.para.WORKINGSETSVM_STOP = 0.01;

    System.out.println("Initializing Solvers...");
    System.out.flush();
    AbstractLossSensitiveStructureFinder[] s_finder_list = new AbstractLossSensitiveStructureFinder[n_thread];
    for (int i = 0; i < s_finder_list.length; i++) {
        s_finder_list[i] = new MultiClassStructureFinder();
    }
    System.out.println("Done!");
    System.out.flush();
    model.s_finder = s_finder_list[0];

    L2LossParallelJLISLearner learner = new L2LossParallelJLISLearner();

    // train model
    model.wv = learner.parallelTrainStructuredSVM(s_finder_list, train.sp, model.para);
    return model;
}
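/*
 * Hypothetical call to the multiclass factory above; 'train' is an assumed,
 * already-loaded LabeledMulticlassData.
 */
MulticlassModel mcModel = trainMultiClassModel(1.0, 8, train); // C = 1.0, 8 solver threads
// mcModel.wv holds the learned weights; mcModel.s_finder can be reused at prediction time.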
private Pair<Integer, Integer> updateStructuresforNegativeAndStructuredExamples(
        L2LossInstanceWithAlphas[] alpha_ins_list, WeightVector new_wv,
        int struct_size, AbstractStructureFinder struct_finder)
        throws Exception {
    int n_s_new = 0;
    int n_b_new = 0;
    int total_size = alpha_ins_list.length;
    for (int i = 0; i < total_size; i++) {
        // positive h has already been fixed
        if (alpha_ins_list[i].isBinary() && alpha_ins_list[i].getY() == 1)
            continue;
        double score = alpha_ins_list[i].updateRepresentationCollection(
                new_wv, struct_finder);
        if (i < struct_size) {
            if (score > L2LossInstanceWithAlphas.DUAL_GAP) {
                n_s_new += 1;
            }
        } else {
            if (score > L2LossInstanceWithAlphas.BINARY_DUAL_GAP) {
                n_b_new += 1;
            }
        }
    }
    return new Pair<Integer, Integer>(n_s_new, n_b_new);
}
/**
 * The function for users to call to train a structured SVM.
 *
 * @param struct_finder
 *            The inference solver (dynamic programming, ILP, ...). Given an
 *            input (IInstance) and a weight vector (WeightVector), it returns
 *            the best structure (AbstractStructures).
 * @param sp
 *            Structured labeled dataset.
 * @param para
 *            Parameters for JLIS.
 * @return the learned weight vector
 * @throws Exception
 */
@Override
public WeightVector trainStructuredSVM(
        final AbstractLossSensitiveStructureFinder struct_finder,
        final StructuredProblem sp, JLISParameters para) throws Exception {
    // +1 because we skip wv.u[0]
    WeightVector wv = new WeightVector(para.total_number_features + 1);
    // wv.setExtendable(false);
    return getJointWeightVectorFast(wv, struct_finder, sp, empty_b, para);
}
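/*
 * Minimal usage sketch for trainStructuredSVM above. The 'learner' instance,
 * the MyLossSensitiveFinder solver, numFeatures, and the populated
 * StructuredProblem sp are illustrative assumptions, not names taken from
 * this codebase.
 */
JLISParameters para = new JLISParameters();
para.c_struct = 1.0;                      // structured penalty parameter (assumed value)
para.total_number_features = numFeatures; // assumed known from feature extraction

AbstractLossSensitiveStructureFinder finder = new MyLossSensitiveFinder(); // hypothetical solver
WeightVector wv = learner.trainStructuredSVM(finder, sp, para); // 'learner' declares trainStructuredSVM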
public WeightVector parallelTrainLatentStructuredSVMWithInitStructures_old(
        final AbstractLatentLossSensitiveStructureFinder[] struct_finder_list,
        final StructuredProblem sp, final JLISParameters para)
        throws Exception {
    WeightVector wv = new WeightVector(para.total_number_features + 1); // +1
    for (int i = 0; i < para.MAX_OUT_ITER; i++) {
        // train the weight vector with the current latent structures fixed
        wv = multiThreadGetJointWeightVector(wv, struct_finder_list, sp,
                empty_b, para);
        // re-infer, for each example, the best latent structure that is
        // consistent with its gold output structure
        for (int j = 0; j < sp.size(); j++) {
            IStructure newLatentStructureWithSameOutputStructure = struct_finder_list[0]
                    .getBestLatentStructure(wv, sp.input_list.get(j),
                            sp.output_list.get(j));
            sp.output_list.set(j, newLatentStructureWithSameOutputStructure);
        }
    }
    return wv;
}
@Override
public void run() {
    int index = 0;
    for (L2LossInstanceWithAlphas ins : alpha_ins_list) {
        // focus on the binary positive examples only
        if (!ins.isBinary() || ins.getY() == -1)
            continue;
        double score = 0;
        try {
            score = ins.updateRepresentationCollection(wv, s_finder);
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        if (score > L2LossInstanceWithAlphas.BINARY_DUAL_GAP) {
            n_b_new += 1;
        }
        // System.out.println("now: " + index);
        index++;
    }
    if (verbose_level >= JLISParameters.VLEVEL_HIGH) {
        System.out.println("Thread: (b) update = " + n_b_new);
    }
}
}
if (ins.isBinary() && ins.getY() == 1)
    continue;
double score = 0;
try {
    score = ins.updateRepresentationCollection(wv, s_finder);
} catch (Exception e) {
    e.printStackTrace();
}
if (ins.isBinary()) {
    if (score > L2LossInstanceWithAlphas.BINARY_DUAL_GAP) {
        n_b_new += 1;