/**
 * Creates a categorizer backed by the given doccat model. The context
 * generator is built from the feature generators reported by the model's
 * factory.
 *
 * @param model the doccat model
 */
public DocumentCategorizerME(DoccatModel model) {
  this.model = model;
  // The factory attached to the model determines how contexts are generated.
  mContextGenerator = new DocumentCategorizerContextGenerator(
      model.getFactory().getFeatureGenerators());
}
/**
 * Creates an event stream over the given samples. Each event's context is
 * produced by a single {@link BagOfWordsFeatureGenerator}.
 *
 * @param samples {@link ObjectStream} of {@link DocumentSample}s
 */
public DocumentCategorizerEventStream(ObjectStream<DocumentSample> samples) {
  super(samples);
  this.mContextGenerator =
      new DocumentCategorizerContextGenerator(new BagOfWordsFeatureGenerator());
}
/**
 * Builds a {@link DoccatModel} from the supplied stream.
 *
 * @param in stream containing the model
 * @return the loaded doccat model
 * @throws IOException if the model cannot be read
 */
@Override
protected DoccatModel loadModel(InputStream in) throws IOException {
  final DoccatModel loaded = new DoccatModel(in);
  return loaded;
}
}
/**
 * Training on a stream that holds only one sample is expected to throw
 * {@link InsufficientTrainingDataException}.
 */
@Test(expected = InsufficientTrainingDataException.class)
public void insufficientTestData() throws IOException {
  TrainingParameters trainingParams = new TrainingParameters();
  trainingParams.put(TrainingParameters.ITERATIONS_PARAM, 100);
  trainingParams.put(TrainingParameters.CUTOFF_PARAM, 0);

  ObjectStream<DocumentSample> singleSample = ObjectStreamUtils.createObjectStream(
      new DocumentSample("1", new String[] {"a", "b", "c"}));

  DocumentCategorizerME.train("x-unspecified", singleSample, trainingParams,
      new DoccatFactory());
}
/**
 * Returns the single event for the current sample and marks it as consumed.
 *
 * @return an event whose outcome is the sample's category and whose context
 *     is generated from the sample's text and extra information
 * @throws java.util.NoSuchElementException if the event has already been returned
 */
public Event next() {
  // Iterator contract: once the element is consumed, next() must fail instead
  // of silently producing the same event again.
  // NOTE(review): assumes isVirgin starts true and hasNext() reports it — confirm.
  if (!isVirgin) {
    throw new java.util.NoSuchElementException();
  }
  isVirgin = false;
  return new Event(sample.getCategory(),
      mContextGenerator.getContext(sample.getText(), sample.getExtraInformation()));
}
/**
 * Trains a model on the test sample stream. It uses the default training
 * parameters and a freshly created {@link DoccatFactory}.
 */
private static DoccatModel train() throws IOException {
  final String languageCode = "x-unspecified";
  return DocumentCategorizerME.train(languageCode, createSampleStream(),
      TrainingParameters.defaultParams(), new DoccatFactory());
}
/**
 * Trains a document categorizer model.
 *
 * @param languageCode language code recorded in the resulting model
 * @param samples the training samples
 * @param mlParams parameters used to obtain and configure the event trainer
 * @param factory supplies the feature generators and is stored in the model
 * @return the trained doccat model
 * @throws IOException propagated from training
 */
public static DoccatModel train(String languageCode, ObjectStream<DocumentSample> samples,
    TrainingParameters mlParams, DoccatFactory factory) throws IOException {
  // This map is handed to the trainer factory and then to the model as its
  // manifest info entries.
  Map<String, String> manifestInfoEntries = new HashMap<>();
  EventTrainer trainer = TrainerFactory.getEventTrainer(mlParams, manifestInfoEntries);

  MaxentModel maxent = trainer.train(
      new DocumentCategorizerEventStream(samples, factory.getFeatureGenerators()));

  return new DoccatModel(languageCode, maxent, manifestInfoEntries, factory);
}
}
/**
 * Creates the "predicted" sample. Its category is {@code anotherCategory} and
 * its tokens are {@code a small text}.
 */
public static DocumentSample createPredSample() {
  String[] tokens = {"a", "small", "text"};
  return new DocumentSample("anotherCategory", tokens);
}
/**
 * Trains a model on the test sample stream with the default training
 * parameters and the given factory.
 *
 * @param factory the doccat factory to train with
 */
private static DoccatModel train(DoccatFactory factory) throws IOException {
  final String languageCode = "x-unspecified";
  return DocumentCategorizerME.train(languageCode, createSampleStream(),
      TrainingParameters.defaultParams(), factory);
}
public FeatureGenerator[] getFeatureGenerators() { if (featureGenerators == null) { if (artifactProvider != null) { String classNames = artifactProvider .getManifestProperty(FEATURE_GENERATORS); if (classNames != null) { this.featureGenerators = loadFeatureGenerators(classNames); } } if (featureGenerators == null) { // could not load using artifact provider // load bag of words as default this.featureGenerators = new FeatureGenerator[]{new BagOfWordsFeatureGenerator()}; } } return featureGenerators; }
/**
 * Categorizes the given text tokens without any extra information.
 *
 * @param text the text tokens to categorize
 * @return the per-outcome scores, as returned by {@link #categorize(String[], Map)}
 */
@Override
public double[] categorize(String[] text) {
  final Map<String, Object> noExtraInformation = Collections.emptyMap();
  return categorize(text, noExtraInformation);
}
/**
 * Categorizes text given as tokens, together with the supplied extra
 * information.
 *
 * @param text text tokens to categorize
 * @param extraInformation additional information passed to the context generator
 * @return the per-outcome scores from the underlying maxent model's {@code eval}
 */
@Override
public double[] categorize(String[] text, Map<String, Object> extraInformation) {
  // Generate the feature context, then score it with the maxent model.
  final double[] scores = model.getMaxentModel()
      .eval(mContextGenerator.getContext(text, extraInformation));
  return scores;
}
/**
 * Returns the category with the best score according to the underlying
 * maxent model.
 *
 * @param outcome scores, e.g. as returned by {@code categorize}
 * @return the name of the best category
 */
public String getBestCategory(double[] outcome) {
  final String best = this.model.getMaxentModel().getBestOutcome(outcome);
  return best;
}
/**
 * Checks value equality of {@link DocumentSample}:
 * <ul>
 *   <li>two separately created gold samples are distinct instances but are equal;</li>
 *   <li>a sample that differs only in category is not equal to them;</li>
 *   <li>an unrelated object is not equal to them.</li>
 * </ul>
 */
@Test
public void testEquals() {
  // Dedicated assertions report both values on failure, unlike assertTrue/assertFalse.
  Assert.assertNotSame(createGoldSample(), createGoldSample());
  Assert.assertEquals(createGoldSample(), createGoldSample());
  Assert.assertNotEquals(createPredSample(), createGoldSample());
  Assert.assertNotEquals(createPredSample(), new Object());
}
/**
 * Creates an event stream over the given samples. Each event's context is
 * produced by the supplied feature generators.
 *
 * @param data {@link ObjectStream} of {@link DocumentSample}s
 * @param featureGenerators the feature generators used to build each context
 */
public DocumentCategorizerEventStream(ObjectStream<DocumentSample> data,
    FeatureGenerator... featureGenerators) {
  super(data);
  this.mContextGenerator = new DocumentCategorizerContextGenerator(featureGenerators);
}
/**
 * Creates a doccat model that wraps an already trained maxent model.
 *
 * @param languageCode the language code of the model
 * @param doccatModel the trained maxent model, stored under {@code DOCCAT_MODEL_ENTRY_NAME}
 * @param manifestInfoEntries additional manifest entries
 * @param factory the doccat factory
 */
public DoccatModel(String languageCode, MaxentModel doccatModel,
    Map<String, String> manifestInfoEntries, DoccatFactory factory) {
  super(COMPONENT_NAME, languageCode, manifestInfoEntries, factory);
  // Register the classifier artifact first, then validate the artifact map.
  this.artifactMap.put(DOCCAT_MODEL_ENTRY_NAME, doccatModel);
  checkArtifactMap();
}
/**
 * Creates the "gold" sample. Its category is {@code aCategory} and its tokens
 * are {@code a small text}.
 */
public static DocumentSample createGoldSample() {
  String[] tokens = {"a", "small", "text"};
  return new DocumentSample("aCategory", tokens);
}
/**
 * Builds a {@link DoccatModel} from the supplied stream.
 *
 * @param modelIn stream containing the model
 * @return the loaded doccat model
 * @throws IOException if the model cannot be read
 */
@Override
protected DoccatModel loadModel(InputStream modelIn) throws IOException {
  final DoccatModel doccatModel = new DoccatModel(modelIn);
  return doccatModel;
}
/**
 * Returns the name of the category at the given outcome index of the
 * underlying maxent model.
 *
 * @param index the outcome index
 * @return the category name
 */
public String getCategory(int index) {
  final String category = this.model.getMaxentModel().getOutcome(index);
  return category;
}
/**
 * Returns how many categories (outcomes) the underlying maxent model has.
 *
 * @return the number of outcomes
 */
public int getNumberOfCategories() {
  final int outcomeCount = this.model.getMaxentModel().getNumOutcomes();
  return outcomeCount;
}