/**
 * Trains a document categorizer model for the tests in this class.
 *
 * <p>The samples come from {@code createSampleStream()}; training uses the
 * default training parameters and a {@link DoccatFactory} created with its
 * no-argument constructor.
 *
 * @return the trained model
 * @throws IOException if training fails with an I/O error
 */
private static DoccatModel train() throws IOException {
  final String languageCode = "x-unspecified";
  return DocumentCategorizerME.train(
      languageCode,
      createSampleStream(),
      TrainingParameters.defaultParams(),
      new DoccatFactory());
}
/**
 * Trains a model with three explicitly configured feature generators,
 * serializes it, reads it back and checks that the restored factory reports
 * the same three generator classes in the same order.
 */
@Test
public void testCustom() throws IOException {
  FeatureGenerator[] customGenerators = new FeatureGenerator[] {
      new BagOfWordsFeatureGenerator(),
      new NGramFeatureGenerator(),
      new NGramFeatureGenerator(2, 3)
  };

  DoccatModel trained = train(new DoccatFactory(customGenerators));
  Assert.assertNotNull(trained);

  // Round trip the model through an in-memory byte buffer.
  ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  trained.serialize(buffer);
  DoccatModel restored = new DoccatModel(new ByteArrayInputStream(buffer.toByteArray()));

  DoccatFactory restoredFactory = restored.getFactory();
  Assert.assertNotNull(restoredFactory);

  FeatureGenerator[] restoredGenerators = restoredFactory.getFeatureGenerators();
  Assert.assertEquals(3, restoredGenerators.length);
  Assert.assertEquals(BagOfWordsFeatureGenerator.class, restoredGenerators[0].getClass());
  Assert.assertEquals(NGramFeatureGenerator.class, restoredGenerators[1].getClass());
  Assert.assertEquals(NGramFeatureGenerator.class, restoredGenerators[2].getClass());
}
/**
 * Extends the manifest entries of the super class with the feature
 * generators of this factory.
 *
 * @return the manifest entries; when {@link #getFeatureGenerators()} is not
 *     {@code null} they include a {@code FEATURE_GENERATORS} entry holding
 *     the result of {@code featureGeneratorsAsString()}
 */
@Override
public Map<String, String> createManifestEntries() {
  Map<String, String> entries = super.createManifestEntries();
  if (getFeatureGenerators() == null) {
    return entries;
  }
  entries.put(FEATURE_GENERATORS, featureGeneratorsAsString());
  return entries;
}
public static DoccatFactory create(String subclassName, FeatureGenerator[] featureGenerators) throws InvalidFormatException { if (subclassName == null) { // will create the default factory return new DoccatFactory(featureGenerators); } try { DoccatFactory theFactory = ExtensionLoader.instantiateExtension( DoccatFactory.class, subclassName); theFactory.init(featureGenerators); return theFactory; } catch (Exception e) { String msg = "Could not instantiate the " + subclassName + ". The initialization throw an exception."; System.err.println(msg); e.printStackTrace(); throw new InvalidFormatException(msg, e); } }
/**
 * Builds the textual form of the configured feature generators: the
 * canonical class name of each generator, in array order, separated by
 * commas. An empty array yields an empty string.
 *
 * @return the comma separated canonical class names
 */
private String featureGeneratorsAsString() {
  FeatureGenerator[] generators = getFeatureGenerators();
  StringBuilder joined = new StringBuilder();
  for (int i = 0; i < generators.length; i++) {
    if (i > 0) {
      joined.append(',');
    }
    joined.append(generators[i].getClass().getCanonicalName());
  }
  return joined.toString();
}
// Resolve the factory from the configured factory class name (the default
// DoccatFactory is used when that name is null) and pass it, together with
// the language, training parameters and listeners, to the cross validator.
DoccatFactory factory = DoccatFactory.create(params.getFactory(), featureGenerators); validator = new DoccatCrossValidator(params.getLang(), mlParams, factory, listenersArr);
public FeatureGenerator[] getFeatureGenerators() { if (featureGenerators == null) { if (artifactProvider != null) { String classNames = artifactProvider .getManifestProperty(FEATURE_GENERATORS); if (classNames != null) { this.featureGenerators = loadFeatureGenerators(classNames); } } if (featureGenerators == null) { // could not load using artifact provider // load bag of words as default this.featureGenerators = new FeatureGenerator[]{new BagOfWordsFeatureGenerator()}; } } return featureGenerators; }
public static DoccatFactory create(String subclassName, FeatureGenerator[] featureGenerators) throws InvalidFormatException { if (subclassName == null) { // will create the default factory return new DoccatFactory(featureGenerators); } try { DoccatFactory theFactory = ExtensionLoader.instantiateExtension( DoccatFactory.class, subclassName); theFactory.init(featureGenerators); return theFactory; } catch (Exception e) { String msg = "Could not instantiate the " + subclassName + ". The initialization throw an exception."; System.err.println(msg); e.printStackTrace(); throw new InvalidFormatException(msg, e); } }
/** * Initializes the current instance with a doccat model. Default feature * generation is used. * * @param model the doccat model */ public DocumentCategorizerME(DoccatModel model) { this.model = model; this.mContextGenerator = new DocumentCategorizerContextGenerator(this.model .getFactory().getFeatureGenerators()); }
@Override public void run(String format, String[] args) { super.run(format, args); mlParams = CmdLineUtil.loadTrainingParameters(params.getParams(), false); if (mlParams == null) { mlParams = ModelUtil.createDefaultTrainingParameters(); } File modelOutFile = params.getModel(); CmdLineUtil.checkOutputFile("document categorizer model", modelOutFile); FeatureGenerator[] featureGenerators = createFeatureGenerators(params .getFeatureGenerators()); DoccatModel model; try { DoccatFactory factory = DoccatFactory.create(params.getFactory(), featureGenerators); model = DocumentCategorizerME.train(params.getLang(), sampleStream, mlParams, factory); } catch (IOException e) { throw createTerminationIOException(e); } finally { try { sampleStream.close(); } catch (IOException e) { // sorry that this can fail } } CmdLineUtil.writeModel("document categorizer", modelOutFile, model); }
public FeatureGenerator[] getFeatureGenerators() { if (featureGenerators == null) { if (artifactProvider != null) { String classNames = artifactProvider .getManifestProperty(FEATURE_GENERATORS); if (classNames != null) { this.featureGenerators = loadFeatureGenerators(classNames); } } if (featureGenerators == null) { // could not load using artifact provider // load bag of words as default this.featureGenerators = new FeatureGenerator[]{new BagOfWordsFeatureGenerator()}; } } return featureGenerators; }
/**
 * Training on a stream that holds a single sample is expected to fail with
 * an {@link InsufficientTrainingDataException}.
 */
@Test(expected = InsufficientTrainingDataException.class)
public void insufficientTestData() throws IOException {
  TrainingParameters trainingParams = new TrainingParameters();
  trainingParams.put(TrainingParameters.ITERATIONS_PARAM, 100);
  trainingParams.put(TrainingParameters.CUTOFF_PARAM, 0);

  DocumentSample onlySample = new DocumentSample("1", new String[] {"a", "b", "c"});
  ObjectStream<DocumentSample> samples = ObjectStreamUtils.createObjectStream(onlySample);

  DocumentCategorizerME.train("x-unspecified", samples, trainingParams, new DoccatFactory());
}
public static DoccatFactory create(String subclassName, FeatureGenerator[] featureGenerators) throws InvalidFormatException { if (subclassName == null) { // will create the default factory return new DoccatFactory(featureGenerators); } try { DoccatFactory theFactory = ExtensionLoader.instantiateExtension( DoccatFactory.class, subclassName); theFactory.init(featureGenerators); return theFactory; } catch (Exception e) { String msg = "Could not instantiate the " + subclassName + ". The initialization throw an exception."; System.err.println(msg); e.printStackTrace(); throw new InvalidFormatException(msg, e); } }
/**
 * Trains a document categorizer model.
 *
 * <p>An event trainer is obtained from the {@code TrainerFactory} for the
 * given training parameters and trained on the events generated from the
 * samples with the factory's feature generators.
 *
 * @param languageCode the language code stored in the resulting model
 * @param samples the training samples
 * @param mlParams the machine learning training parameters
 * @param factory the factory that supplies the feature generators and is
 *     stored in the resulting model
 * @return the trained model
 * @throws IOException if training fails with an I/O error
 */
public static DoccatModel train(String languageCode, ObjectStream<DocumentSample> samples,
    TrainingParameters mlParams, DoccatFactory factory) throws IOException {

  Map<String, String> manifestEntries = new HashMap<>();
  EventTrainer eventTrainer = TrainerFactory.getEventTrainer(mlParams, manifestEntries);

  DocumentCategorizerEventStream events =
      new DocumentCategorizerEventStream(samples, factory.getFeatureGenerators());
  MaxentModel maxentModel = eventTrainer.train(events);

  return new DoccatModel(languageCode, maxentModel, manifestEntries, factory);
}
}
/**
 * Creates the manifest entries for this factory.
 *
 * <p>Starts from the entries of the super class and, if
 * {@link #getFeatureGenerators()} does not return {@code null}, stores the
 * string form of the feature generators under {@code FEATURE_GENERATORS}.
 *
 * @return the manifest entries
 */
@Override
public Map<String, String> createManifestEntries() {
  Map<String, String> result = super.createManifestEntries();
  boolean hasGenerators = getFeatureGenerators() != null;
  if (hasGenerators) {
    result.put(FEATURE_GENERATORS, featureGeneratorsAsString());
  }
  return result;
}
// Create the factory named by params.getFactory() with the given feature
// generators, then construct the cross validator with that factory, the
// language, the training parameters and the listeners.
DoccatFactory factory = DoccatFactory.create(params.getFactory(), featureGenerators); validator = new DoccatCrossValidator(params.getLang(), mlParams, factory, listenersArr);
public FeatureGenerator[] getFeatureGenerators() { if (featureGenerators == null) { if (artifactProvider != null) { String classNames = artifactProvider .getManifestProperty(FEATURE_GENERATORS); if (classNames != null) { this.featureGenerators = loadFeatureGenerators(classNames); } } if (featureGenerators == null) { // could not load using artifact provider // load bag of words as default this.featureGenerators = new FeatureGenerator[]{new BagOfWordsFeatureGenerator()}; } } return featureGenerators; }
@Test public void testSimpleTraining() throws IOException { ObjectStream<DocumentSample> samples = ObjectStreamUtils.createObjectStream( new DocumentSample("1", new String[]{"a", "b", "c"}), new DocumentSample("1", new String[]{"a", "b", "c", "1", "2"}), new DocumentSample("1", new String[]{"a", "b", "c", "3", "4"}), new DocumentSample("0", new String[]{"x", "y", "z"}), new DocumentSample("0", new String[]{"x", "y", "z", "5", "6"}), new DocumentSample("0", new String[]{"x", "y", "z", "7", "8"})); TrainingParameters params = new TrainingParameters(); params.put(TrainingParameters.ITERATIONS_PARAM, 100); params.put(TrainingParameters.CUTOFF_PARAM, 0); DoccatModel model = DocumentCategorizerME.train("x-unspecified", samples, params, new DoccatFactory()); DocumentCategorizer doccat = new DocumentCategorizerME(model); double[] aProbs = doccat.categorize(new String[]{"a"}); Assert.assertEquals("1", doccat.getBestCategory(aProbs)); double[] bProbs = doccat.categorize(new String[]{"x"}); Assert.assertEquals("0", doccat.getBestCategory(bProbs)); //test to make sure sorted map's last key is cat 1 because it has the highest score. SortedMap<Double, Set<String>> sortedScoreMap = doccat.sortedScoreMap(new String[]{"a"}); Set<String> cat = sortedScoreMap.get(sortedScoreMap.lastKey()); Assert.assertEquals(1, cat.size()); }
/**
 * Trains a model with the default setup, serializes it, reads it back and
 * checks that the restored factory reports a single
 * {@link BagOfWordsFeatureGenerator}.
 */
@Test
public void testDefault() throws IOException {
  DoccatModel trained = train();
  Assert.assertNotNull(trained);

  // Round trip the model through an in-memory byte buffer.
  ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  trained.serialize(buffer);
  DoccatModel restored = new DoccatModel(new ByteArrayInputStream(buffer.toByteArray()));

  DoccatFactory restoredFactory = restored.getFactory();
  Assert.assertNotNull(restoredFactory);

  FeatureGenerator[] restoredGenerators = restoredFactory.getFeatureGenerators();
  Assert.assertEquals(1, restoredGenerators.length);
  Assert.assertEquals(BagOfWordsFeatureGenerator.class, restoredGenerators[0].getClass());
}
/**
 * Collects the manifest entries written for this factory: those of the super
 * class plus, when {@link #getFeatureGenerators()} is not {@code null}, the
 * {@code FEATURE_GENERATORS} entry produced by
 * {@code featureGeneratorsAsString()}.
 *
 * @return the manifest entries
 */
@Override
public Map<String, String> createManifestEntries() {
  final Map<String, String> manifest = super.createManifestEntries();
  if (null != getFeatureGenerators()) {
    String generatorNames = featureGeneratorsAsString();
    manifest.put(FEATURE_GENERATORS, generatorNames);
  }
  return manifest;
}