/**
 * Reads a dictionary from character data holding one element per line.
 *
 * <p>Each line is trimmed and converted through the function returned by
 * {@code valueOfFunction(elementClass)}; the element's index is its position in reading
 * order. The supplied reader is not closed by this method.
 *
 * @param elementClass the class of the dictionary elements
 * @param reader the source of the lines; wrapped in a {@link BufferedReader} unless it
 *        already is one
 * @return a new dictionary holding the elements read
 * @throws IOException if reading a line fails
 */
public static <T> Dictionary<T> readFrom(final Class<T> elementClass, final Reader reader)
        throws IOException {
    final Dictionary<T> result = create();

    final BufferedReader lines;
    if (reader instanceof BufferedReader) {
        lines = (BufferedReader) reader;
    } else {
        lines = new BufferedReader(reader);
    }

    final Function<String, T> parser = valueOfFunction(elementClass);
    for (String text = lines.readLine(); text != null; text = lines.readLine()) {
        final T parsed = parser.apply(text.trim());
        // The next free index is the current list size
        result.map.put(parsed, result.list.size());
        result.list.add(parsed);
    }
    return result;
}
/**
 * Reads a dictionary from a file, choosing the decoder from the file name.
 *
 * <p>When {@code isTextFormat} accepts the path string, the file is read through a UTF-8
 * reader; otherwise the raw stream is handed to the stream-based overload. The file stream
 * is closed before this method returns.
 *
 * @param elementClass the class of the dictionary elements
 * @param path the file to read
 * @return the dictionary loaded from the file
 * @throws IOException if the file cannot be opened or read
 */
public static <T> Dictionary<T> readFrom(final Class<T> elementClass, final Path path)
        throws IOException {
    try (InputStream in = IO.buffer(Files.newInputStream(path))) {
        final boolean binary = !isTextFormat(path.toString());
        if (binary) {
            return readFrom(elementClass, in);
        }
        return readFrom(elementClass, IO.utf8Reader(in));
    }
}
/**
 * Writes this dictionary to a file, choosing the encoder from the file name.
 *
 * <p>The file is created if missing and truncated if it already exists. When
 * {@code isTextFormat} accepts the path string the content is written through a UTF-8
 * writer; otherwise it is written to the raw stream. The file stream is closed before this
 * method returns.
 *
 * @param path the file to write
 * @throws IOException if the file cannot be opened or written
 */
public void writeTo(final Path path) throws IOException {
    try (OutputStream out = IO.buffer(Files.newOutputStream(path,
            StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.CREATE))) {
        final boolean binary = !isTextFormat(path.toString());
        if (binary) {
            writeTo(out);
            return;
        }
        writeTo(IO.utf8Writer(out));
    }
}
private static Classifier trainJava(final Parameters parameters, final Iterable<LabelledVector> trainingSet) throws IOException { // Prepare the svm_parameter object based on supplied parameters final Parameter parameter = encodeParameters(parameters); parameter.setEps(getDefaultEpsilon(parameters) * 0.1f); // Encode the training set as an svm_problem object, filling a dictionary meanwhile final Dictionary<String> dictionary = Dictionary.create(); dictionary.indexFor("_unused"); // just to avoid using feature index 0 final Problem problem = encodeProblem(dictionary, trainingSet, parameters); // Perform training final Model model = Linear.train(problem, parameter); // Compute model hash final StringWriter writer = new StringWriter(); Linear.saveModel(writer, model); final String modelString = writer.toString(); final String modelHash = computeHash(dictionary, modelString); // Build and return the SVM object return new LibLinearClassifier(parameters, modelHash, dictionary, model); }
/**
 * Reads a dictionary from a binary stream: an {@code int} element count followed by that
 * many serialized objects, in index order.
 *
 * <p>The supplied stream is not closed by this method.
 *
 * @param elementClass the class every stored element must be an instance of
 * @param stream the stream to read from
 * @return a new dictionary holding the elements read, indexed by their position
 * @throws IOException if reading fails, or if the content is invalid (an element's class
 *         cannot be found, or an element is not an instance of {@code elementClass})
 */
public static <T> Dictionary<T> readFrom(final Class<T> elementClass, final InputStream stream)
        throws IOException {
    final Dictionary<T> dictionary = create();
    // NOTE(review): Java-native deserialization instantiates whatever classes the stream
    // names; only feed this method data from a trusted source.
    final ObjectInputStream ois = new ObjectInputStream(stream);
    final int size = ois.readInt();
    try {
        for (int i = 0; i < size; ++i) {
            final T element = elementClass.cast(ois.readObject());
            dictionary.map.put(element, i);
            dictionary.list.add(element);
        }
    } catch (final ClassNotFoundException | ClassCastException ex) {
        // An element of the wrong type is malformed content too, so report it the same
        // way as an unknown class instead of leaking an unchecked ClassCastException.
        throw new IOException("Invalid file content", ex);
    }
    return dictionary;
}
static Classifier doRead(final Parameters parameters, final Path path) throws IOException { // Read the dictionary final Dictionary<String> dictionary = Dictionary.readFrom(String.class, path.resolve("dictionary")); // Read the model final String modelString = new String(Files.readAllBytes(path.resolve("model")), Charsets.UTF_8); final Model model = Model.load(new StringReader(modelString)); // Compute model hash final String modelHash = computeHash(dictionary, modelString); // Create and return the SVM return new LibLinearClassifier(parameters, modelHash, dictionary, model); }
@Override void doWrite(final Path path) throws IOException { // Write the dictionary this.dictionary.writeTo(path.resolve("dictionary")); // Write the model try (BufferedWriter writer = Files.newBufferedWriter(path.resolve("model"))) { Linear.saveModel(writer, this.model); } }
/**
 * Returns a dictionary with the same content backed by immutable collections.
 *
 * @return this instance if its list is already an {@code ImmutableList}, otherwise a new
 *         dictionary built from immutable copies of the map and the list
 */
public Dictionary<T> freeze() {
    final boolean alreadyFrozen = this.list instanceof ImmutableList<?>;
    if (!alreadyFrozen) {
        return new Dictionary<T>(
                ImmutableMap.copyOf(this.map),
                ImmutableList.copyOf(this.list));
    }
    return this;
}
/**
 * Appends the supplied vectors to {@code out}, one line per vector, in the form
 * {@code label index:value index:value ...}.
 *
 * <p>The label is the vector's label for a {@code LabelledVector} and {@code 0} otherwise.
 * Features whose name starts with {@code '_'} and features for which the dictionary returns
 * no index are left out. The remaining entries are sorted with {@link Collections#sort}
 * before being written.
 *
 * @param vectors the vectors to write
 * @param dictionary the dictionary supplying the index of each feature
 * @param out the sink to append to
 * @return {@code out}
 * @throws IOException if appending to {@code out} fails
 */
public static <T extends Appendable> T write(final Iterable<? extends Vector> vectors,
        final Dictionary<String> dictionary, final T out) throws IOException {
    for (final Vector vector : vectors) {

        // Label column
        final String label;
        if (vector instanceof LabelledVector) {
            label = Integer.toString(((LabelledVector) vector).getLabel());
        } else {
            label = "0";
        }
        out.append(label);

        // Collect the (index, value) pairs that are to be emitted
        final int count = vector.size();
        final List<IndexValue> entries = Lists.newArrayListWithCapacity(count);
        for (int pos = 0; pos < count; ++pos) {
            final String name = vector.getFeature(pos);
            if (name.charAt(0) == '_') {
                continue;
            }
            final Integer mapped = dictionary.indexFor(vector.getFeature(pos));
            if (mapped == null) {
                continue;
            }
            entries.add(new IndexValue(mapped, vector.getValue(pos)));
        }
        Collections.sort(entries);

        // Emit the pairs, then terminate the line
        for (final IndexValue entry : entries) {
            final String index = Integer.toString(entry.index);
            final String value = Float.toString(entry.value);
            out.append(' ');
            out.append(index);
            out.append(':');
            out.append(value);
        }
        out.append('\n');
    }
    return out;
}
private LibSvmClassifier(final Parameters parameters, final String modelHash, final Dictionary<String> dictionary, final svm_model model) { super(parameters, modelHash); this.dictionary = dictionary.freeze(); this.model = model; }
/**
 * Returns the element stored at the supplied index.
 *
 * @param index the index of the element
 * @return the element at that position in the list
 * @throws IllegalArgumentException if {@code index} is negative or not smaller than the
 *         number of elements in the list
 */
@Nullable
public T elementFor(final int index) {
    // Check the bounds up front rather than catching IndexOutOfBoundsException from
    // List.get(): same outcome for callers, without using an exception as control flow.
    if (index < 0 || index >= this.list.size()) {
        throw new IllegalArgumentException(
                "No element for index " + index + " (size is " + size() + ")");
    }
    return this.list.get(index);
}
private static Classifier trainJava(final Parameters parameters, final Iterable<LabelledVector> trainingSet) throws IOException { // Prepare the svm_parameter object based on supplied parameters final svm_parameter parameter = encodeParameters(parameters); // Encode the training set as an svm_problem object, filling a dictionary meanwhile final Dictionary<String> dictionary = Dictionary.create(); final svm_problem problem = encodeProblem(dictionary, trainingSet); // Perform training final svm_model model = svm.svm_train(problem, parameter); // Compute model hash, by saving and reloading SVM model final File tmpFile = File.createTempFile("svm", ".bin"); tmpFile.deleteOnExit(); svm.svm_save_model(tmpFile.getAbsolutePath(), model); final String modelString = com.google.common.io.Files.toString(tmpFile, Charset.defaultCharset()); final String modelHash = computeHash(dictionary, modelString); final svm_model reloadedModel = svm .svm_load_model(new BufferedReader(new StringReader(modelString))); tmpFile.delete(); // Build and return the SVM object return new LibSvmClassifier(parameters, modelHash, dictionary, reloadedModel); }
static Classifier doRead(final Parameters parameters, final Path path) throws IOException { // Read the dictionary final Dictionary<String> dictionary = Dictionary.readFrom(String.class, path.resolve("dictionary")); // Read the model final String modelString = new String(Files.readAllBytes(path.resolve("model")), Charsets.UTF_8); final svm_model model = svm .svm_load_model(new BufferedReader(new StringReader(modelString))); // Compute model hash final String modelHash = computeHash(dictionary, modelString); // Create and return the SVM return new LibSvmClassifier(parameters, modelHash, dictionary, model); }
@Override void doWrite(final Path path) throws IOException { // Write the dictionary this.dictionary.writeTo(path.resolve("dictionary")); // Write the model final File tmpFile = File.createTempFile("svm", ".bin"); tmpFile.deleteOnExit(); svm.svm_save_model(tmpFile.getAbsolutePath(), this.model); final String modelString = com.google.common.io.Files.toString(tmpFile, Charset.defaultCharset()); tmpFile.delete(); Files.write(path.resolve("model"), modelString.getBytes(Charsets.UTF_8)); }
/**
 * Creates a new, empty dictionary backed by a {@link HashMap} and an {@link ArrayList}.
 *
 * @return the new dictionary
 */
public static <T> Dictionary<T> create() {
    final HashMap<T, Integer> indexes = new HashMap<T, Integer>();
    final ArrayList<T> elements = new ArrayList<T>();
    return new Dictionary<T>(indexes, elements);
}
/**
 * Converts a vector into a sorted array of LibLinear feature nodes.
 *
 * <p>Features whose name starts with {@code '_'} and features for which the dictionary
 * returns no index are left out. The result is sorted with {@code FEATURE_ORDERING}.
 *
 * @param dictionary the dictionary supplying the index of each feature
 * @param vector the vector to encode
 * @return the encoded features, with no unused slots
 */
private static Feature[] encodeVector(final Dictionary<String> dictionary,
        final Vector vector) {
    final int total = vector.size();
    final Feature[] buffer = new Feature[total];
    int used = 0;
    for (int pos = 0; pos < total; ++pos) {
        final String name = vector.getFeature(pos);
        if (name.charAt(0) == '_') {
            continue;
        }
        final Integer mapped = dictionary.indexFor(vector.getFeature(pos));
        if (mapped == null) {
            continue;
        }
        buffer[used] = new FeatureNode(mapped, vector.getValue(pos));
        ++used;
    }

    // Drop the unused tail when some features were skipped
    final Feature[] encoded = used < total ? Arrays.copyOfRange(buffer, 0, used) : buffer;
    Arrays.sort(encoded, FEATURE_ORDERING);
    return encoded;
}
// Start from an empty dictionary
final Dictionary<String> dictionary = Dictionary.create();
// Temporary file, registered for deletion when the JVM exits
// NOTE(review): presumably receives the encoded training data - confirm against the
// enclosing method, which is not visible here
final File trainingFile = File.createTempFile("svmdata.", ".txt");
trainingFile.deleteOnExit();
/**
 * Creates a new dictionary initialized with the content of another one.
 *
 * <p>The map and the list of the source are copied into a fresh {@link HashMap} and
 * {@link ArrayList}, so the new dictionary does not share its collections with the source.
 *
 * @param dictionary the dictionary to copy
 * @return the new dictionary
 */
public static <T> Dictionary<T> create(final Dictionary<T> dictionary) {
    final HashMap<T, Integer> indexes = new HashMap<T, Integer>(dictionary.map);
    final ArrayList<T> elements = new ArrayList<T>(dictionary.list);
    return new Dictionary<T>(indexes, elements);
}
/**
 * Converts a vector into a sorted array of LibSVM nodes.
 *
 * <p>Features whose name starts with {@code '_'} and features for which the dictionary
 * returns no index are left out. The result is sorted with {@code NODE_ORDERING}.
 *
 * @param dictionary the dictionary supplying the index of each feature
 * @param vector the vector to encode
 * @return the encoded nodes, with no unused slots
 */
private static svm_node[] encodeVector(final Dictionary<String> dictionary,
        final Vector vector) {
    final int total = vector.size();
    final svm_node[] buffer = new svm_node[total];
    int used = 0;
    for (int pos = 0; pos < total; ++pos) {
        final String name = vector.getFeature(pos);
        if (name.charAt(0) == '_') {
            continue;
        }
        final Integer mapped = dictionary.indexFor(vector.getFeature(pos));
        if (mapped == null) {
            continue;
        }
        final svm_node entry = new svm_node();
        entry.index = mapped;
        entry.value = vector.getValue(pos);
        buffer[used] = entry;
        ++used;
    }

    // Drop the unused tail when some features were skipped
    final svm_node[] encoded = used < total ? Arrays.copyOfRange(buffer, 0, used) : buffer;
    Arrays.sort(encoded, NODE_ORDERING);
    return encoded;
}