/**
 * Dumps the weight vector of a saved reranker model into a text file.
 *
 * <p>Each entry of the model's weight vector becomes one line of the form
 * {@code index:value}, in increasing index order. A confirmation message is printed to
 * standard output once the file has been written.
 *
 * @param model_name  path of the serialized model; it is cast to {@code RerankerModel}
 * @param output_file path of the text file that receives the weights
 * @throws IOException            propagated from loading the model or writing the output
 * @throws ClassNotFoundException propagated from {@code JLISModelIOManager.loadModel}
 */
@CommandDescription(description = "<DESCRIPTION>\n"
        + "\tThis procedure is used to extract the weight vector from a saved model \n"
        + "<INPUT>\n"
        + "\tIt receives 2 different arguments. \n"
        + "\t1) model_file (a string), the file name of a trained model \n"
        + "\t2) output_file (a string), the file name that will be used to put the contain of the weight vector. \n"
        + "<OUTPUT>\n"
        + "\tThe weight vector will be output in the ${output_file}.")
public static void outputWeightVector(String model_name, String output_file)
        throws IOException, ClassNotFoundException {
    RerankerModel reranker = (RerankerModel) new JLISModelIOManager().loadModel(model_name);
    double[] weights = reranker.wv.getInternalArray();

    ArrayList<String> lines = new ArrayList<String>(weights.length);
    int idx = 0;
    while (idx < weights.length) {
        lines.add(idx + ":" + weights[idx]);
        idx++;
    }

    LineIO.write(output_file, lines);
    System.out.println("Finish putting the weight vector at " + output_file);
}
}
/**
 * Trains a reranker model with a structured SVM and saves it as
 * {@code <training file name>.ssvm.model} in the current working directory.
 *
 * <p>The numeric arguments are parsed and validated before the training file is read.
 * A malformed or out-of-range command line therefore fails immediately rather than after a
 * potentially expensive data load.
 *
 * @param train_name   path of the training feature file
 * @param C_st_str     regularization parameter C; must parse to a finite number &gt; 0
 * @param n_thread_str number of training threads; must parse to an integer &gt;= 1
 * @throws IllegalArgumentException if C or n_thread is malformed or out of range
 *                                  ({@link NumberFormatException} for unparsable text)
 * @throws Exception                propagated from reading the data, training, or saving
 */
@CommandDescription(description = "<DESCRIPTION>\n"
        + "\tThis procedure is used to train a reranker model using ssvm \n"
        + "<INPUT>\n"
        + "\tIt receives 3 different arguments. \n"
        + "\t1) train_file (a string), the file name of the training data \n"
        + "\t2) C (a real number > 0), a regularization parameter. \n"
        + "\t3) n_thread (an integer), which indicates how many thread you want to use (you do not want to use more threads than the number of cores you have in your computer). \n"
        + "<OUTPUT>\n"
        + "\tThe trained model file will be saved to ${train_file}.ssvm.model in the current working directory.")
public static void trainReranker(String train_name, String C_st_str, String n_thread_str)
        throws Exception {
    // Validate the cheap arguments up front so bad input fails before any file I/O.
    double c = Double.parseDouble(C_st_str);
    if (!(c > 0) || Double.isInfinite(c)) { // the negated comparison also rejects NaN
        throw new IllegalArgumentException(
                "C must be a real number > 0, got: " + C_st_str);
    }
    int nThread = Integer.parseInt(n_thread_str);
    if (nThread < 1) {
        throw new IllegalArgumentException(
                "n_thread must be an integer >= 1, got: " + n_thread_str);
    }

    StructuredProblem sp = RerankDataReader.readFeatureFile(train_name);
    RerankerModel model = RerankTrainer.trainRerankerModel(c, nThread, sp);

    // The directory part of train_name is dropped, so the model lands in the working dir.
    String model_name = JLISUtils.getFileNameWithoutDir(train_name) + ".ssvm.model";
    JLISModelIOManager io = new JLISModelIOManager();
    io.saveModel(model, model_name);
}
/**
 * Dumps the weight vector of a saved multiclass model into a text file, grouped by label.
 *
 * <p>The weight vector is treated as consecutive blocks of
 * {@code n_base_feature_in_train} entries, one block per label, in the order given by
 * {@code getReverseMapping()}. Each block produces:
 * <ul>
 *   <li>a header line {@code Label:<label name>}, then</li>
 *   <li>one {@code featureIndex:value} line per entry.</li>
 * </ul>
 * The last entry of every block has {@code " (bias)"} appended. A confirmation message is
 * printed to standard output once the file has been written.
 *
 * @param model_name  path of the serialized model; it is cast to {@code MulticlassModel}
 * @param output_file path of the text file that receives the weights
 * @throws IOException            propagated from loading the model or writing the output
 * @throws ClassNotFoundException propagated from {@code JLISModelIOManager.loadModel}
 */
@CommandDescription(description = "<DESCRIPTION>\n"
        + "\tThis procedure is used to extract the weight vector from a saved model \n"
        + "<INPUT>\n"
        + "\tIt receives 2 different arguments. \n"
        + "\t1) model_file (a string), the file name of a trained model \n"
        + "\t2) output_file (a string), the file name that will be used to put the contain of the weight vector. \n"
        + "<OUTPUT>\n"
        + "\tThe weight vector will be output in the ${output_file}.")
public static void outputWeightVector(String model_name, String output_file)
        throws IOException, ClassNotFoundException {
    MulticlassModel mc = (MulticlassModel) new JLISModelIOManager().loadModel(model_name);
    String[] labels = mc.getReverseMapping();
    double[] weights = mc.wv.getInternalArray();
    int blockSize = mc.n_base_feature_in_train;
    int biasIndex = blockSize - 1;

    ArrayList<String> lines = new ArrayList<String>();
    for (int label = 0; label < labels.length; label++) {
        lines.add("Label:" + labels[label]);
        int offset = label * blockSize;
        for (int f = 0; f < blockSize; f++) {
            String entry = f + ":" + weights[offset + f];
            lines.add(f == biasIndex ? entry + " (bias)" : entry);
        }
    }

    LineIO.write(output_file, lines);
    System.out.println("Finish putting the weight vector at " + output_file);
}
/**
 * Trains a standard multiclass classification model with a structured SVM and saves it as
 * {@code <training file name>.ssvm.model} in the current working directory.
 *
 * <p>The numeric arguments are parsed and validated before the training data is read.
 * A malformed or out-of-range command line therefore fails immediately rather than after a
 * potentially expensive data load.
 *
 * @param train_name   path of the training data file
 * @param C_st_str     regularization parameter C; must parse to a finite number &gt; 0
 * @param n_thread_str number of training threads; must parse to an integer &gt;= 1
 * @throws IllegalArgumentException if C or n_thread is malformed or out of range
 *                                  ({@link NumberFormatException} for unparsable text)
 * @throws Exception                propagated from reading the data, training, or saving
 */
@CommandDescription(description = "<DESCRIPTION>\n"
        + "\tThis procedure is used to train a standard multiclass classification model using ssvm \n"
        + "<INPUT>\n"
        + "\tIt receives 3 different arguments. \n"
        + "\t1) train_file (a string), the file name of the training data \n"
        + "\t2) C (a real number > 0), a regularization parameter. \n"
        + "\t3) n_thread (an integer), which indicates how many thread you want to use (you do not want to use more threads than the number of cores you have in your computer). \n"
        + "<OUTPUT>\n"
        + "\tThe trained model file will be saved to ${train_file}.ssvm.model in the current working directory.")
public static void trainMultiClass(String train_name, String C_st_str, String n_thread_str)
        throws Exception {
    // Validate the cheap arguments up front so bad input fails before any file I/O.
    double c = Double.parseDouble(C_st_str);
    if (!(c > 0) || Double.isInfinite(c)) { // the negated comparison also rejects NaN
        throw new IllegalArgumentException(
                "C must be a real number > 0, got: " + C_st_str);
    }
    int nThread = Integer.parseInt(n_thread_str);
    if (nThread < 1) {
        throw new IllegalArgumentException(
                "n_thread must be an integer >= 1, got: " + n_thread_str);
    }

    LabeledMulticlassData train = MultiClassSparseLabeledDataReader.readTrainingData(train_name);
    MulticlassModel model = MultiClassTrainer.trainMultiClassModel(c, nThread, train);

    // The directory part of train_name is dropped, so the model lands in the working dir.
    String model_name = JLISUtils.getFileNameWithoutDir(train_name) + ".ssvm.model";
    JLISModelIOManager io = new JLISModelIOManager();
    io.saveModel(model, model_name);
}
public static void testMultiClass(String model_name, String test_name, String output_name) throws Exception { JLISModelIOManager io = new JLISModelIOManager(); MulticlassModel model = (MulticlassModel) io.loadModel(model_name); LabeledMulticlassData test = MultiClassSparseLabeledDataReader .readTestingData(test_name, model.lab_mapping,
/**
 * Trains a cost-sensitive multiclass classification model with a structured SVM and saves it
 * as {@code <training file name>.ssvm.model} in the current working directory.
 *
 * <p>The numeric arguments are parsed and validated before either the training data or the
 * cost matrix is read. A malformed or out-of-range command line therefore fails immediately
 * rather than after potentially expensive file loads.
 *
 * @param train_name            path of the training data file
 * @param cost_matrix_file_name path of the cost-matrix file; it is resolved against the label
 *                              mapping of the training data
 * @param C_st_str              regularization parameter C; must parse to a finite number &gt; 0
 * @param n_thread_str          number of training threads; must parse to an integer &gt;= 1
 * @throws IllegalArgumentException if C or n_thread is malformed or out of range
 *                                  ({@link NumberFormatException} for unparsable text)
 * @throws Exception                propagated from reading the data or cost matrix, training,
 *                                  or saving
 */
@CommandDescription(description = "<DESCRIPTION>\n"
        + "\tThis procedure is used to train a cost-sensitive multiclass classification model using ssvm \n"
        + "<INPUT>\n"
        + "\tIt receives 4 different arguments. \n"
        + "\t1) train_file (a string), the file name of the training data \n"
        + "\t2) cost_matrix_file (a string), the file name of a cost matrix. \n"
        + "\t3) C (a real number > 0), a regularization parameter. \n"
        + "\t4) n_thread (a integer), which indicates how many thread you want to use (you do not want to use more threads than the number of cores you have in your computer). \n"
        + "<OUTPUT>\n"
        + "\tThe trained model file will be saved to ${train_file}.ssvm.model in the current working directory.")
public static void trainCostSensitiveMultiClass(String train_name,
        String cost_matrix_file_name, String C_st_str, String n_thread_str) throws Exception {
    // Validate the cheap arguments up front so bad input fails before any file I/O.
    double c = Double.parseDouble(C_st_str);
    if (!(c > 0) || Double.isInfinite(c)) { // the negated comparison also rejects NaN
        throw new IllegalArgumentException(
                "C must be a real number > 0, got: " + C_st_str);
    }
    int nThread = Integer.parseInt(n_thread_str);
    if (nThread < 1) {
        throw new IllegalArgumentException(
                "n_thread must be an integer >= 1, got: " + n_thread_str);
    }

    LabeledMulticlassData train = MultiClassSparseLabeledDataReader.readTrainingData(train_name);
    // The cost matrix is looked up through the label mapping built from the training data.
    double[][] cost_matrix = MultiClassSparseLabeledDataReader
            .getCostMatrix(train.label_mapping, cost_matrix_file_name);
    MulticlassModel model = MultiClassTrainer
            .trainCostSensitiveMultiClassModel(c, nThread, train, cost_matrix);

    // The directory part of train_name is dropped, so the model lands in the working dir.
    String model_name = JLISUtils.getFileNameWithoutDir(train_name) + ".ssvm.model";
    JLISModelIOManager io = new JLISModelIOManager();
    io.saveModel(model, model_name);
}
public static void testReranker(String model_name, String test_name, String output_name) throws Exception { JLISModelIOManager io = new JLISModelIOManager(); RerankerModel model = (RerankerModel) io.loadModel(model_name); StructuredProblem sp = RerankDataReader.readFeatureFile(test_name);
String cost_matrix_file_name, String test_name, String output_name) throws Exception { JLISModelIOManager io = new JLISModelIOManager(); MulticlassModel model = (MulticlassModel) io.loadModel(model_name);