@Before
public void setup() {
  // Build the shared ML configuration once per test: GIS/MAXENT with a small
  // iteration count and cutoff so the training-based tests stay fast.
  TrainingParameters params = new TrainingParameters();
  params.put(TrainingParameters.ALGORITHM_PARAM, GISTrainer.MAXENT_VALUE);
  params.put(TrainingParameters.ITERATIONS_PARAM, 10);
  params.put(TrainingParameters.CUTOFF_PARAM, 5);
  mlParams = params;
}
/**
 * Retrieves all parameters without a name space.
 *
 * @return the settings map
 *
 * @deprecated use {@link #getObjectSettings()} instead
 */
@Deprecated // was missing: Javadoc declared @deprecated but the annotation was absent
public Map<String, String> getSettings() {
  // Delegates to the namespace-aware overload; null selects the root namespace.
  return getSettings(null);
}
/**
 * Loads training parameters from the given file, or returns the library
 * defaults when no file name is supplied.
 *
 * @param inFileValue path to a parameters file, or {@code null} for defaults
 * @param isSequenceTrainingAllowed whether a SEQUENCE_TRAINER configuration is acceptable
 * @return the loaded (and validated) training parameters
 * @throws ResourceInitializationException if the file cannot be read, is invalid,
 *         or requests sequence training when that is not allowed
 */
public static TrainingParameters loadTrainingParams(String inFileValue,
    boolean isSequenceTrainingAllowed) throws ResourceInitializationException {

  // No file supplied: fall back to the built-in defaults.
  if (inFileValue == null) {
    return TrainingParameters.defaultParams();
  }

  TrainingParameters params;
  try (InputStream paramsIn = new FileInputStream(new File(inFileValue))) {
    params = new opennlp.tools.util.TrainingParameters(paramsIn);
  } catch (IOException e) {
    throw new ResourceInitializationException(e);
  }

  if (!TrainerFactory.isValid(params)) {
    throw new ResourceInitializationException(
        new Exception("Training parameters file is invalid!"));
  }

  // Reject sequence-trainer configurations when the caller cannot handle them.
  TrainerFactory.TrainerType trainerType = TrainerFactory.getTrainerType(params);
  if (!isSequenceTrainingAllowed
      && TrainerFactory.TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
    throw new ResourceInitializationException(
        new Exception("Sequence training is not supported!"));
  }

  return params;
}
}
@Override
public void init(TrainingParameters trainingParameters, Map<String, String> reportMap) {
  super.init(trainingParameters, reportMap);

  // Resolve the QN hyper-parameters from the configuration, falling back to
  // the class defaults when a key is absent.
  this.m = trainingParameters.getIntParameter(M_PARAM, M_DEFAULT);
  this.maxFctEval = trainingParameters.getIntParameter(MAX_FCT_EVAL_PARAM, MAX_FCT_EVAL_DEFAULT);
  this.threads = trainingParameters.getIntParameter(THREADS_PARAM, THREADS_DEFAULT);
  this.l1Cost = trainingParameters.getDoubleParameter(L1COST_PARAM, L1COST_DEFAULT);
  this.l2Cost = trainingParameters.getDoubleParameter(L2COST_PARAM, L2COST_DEFAULT);
}
public static boolean isValid(TrainingParameters trainParams) { // TODO: Need to validate all parameters correctly ... error prone?! String algorithmName = trainParams.getStringParameter(AbstractTrainer.ALGORITHM_PARAM,null); // If a trainer type can be determined, then the trainer is valid! if (algorithmName != null && !(BUILTIN_TRAINERS.containsKey(algorithmName) || getTrainerType(trainParams) != null)) { return false; } try { // require that the Cutoff and the number of iterations be an integer. // if they are not set, the default values will be ok. trainParams.getIntParameter(AbstractTrainer.CUTOFF_PARAM, 0); trainParams.getIntParameter(AbstractTrainer.ITERATIONS_PARAM, 0); } catch (NumberFormatException e) { return false; } // no reason to require that the dataIndexer be a 1-pass or 2-pass dataindexer. trainParams.getStringParameter(AbstractEventTrainer.DATA_INDEXER_PARAM, null); // TODO: Check data indexing ... return true; }
// NOTE(review): fragment — the enclosing method starts before this chunk and the
// closing braces lie outside it; tokens below are unchanged, only comments added.
// Back-compat shim: honor the deprecated log-likelihood threshold parameter.
if (trainingParameters.getDoubleParameter(OLD_LL_THRESHOLD_PARAM, -1.) > 0. ) {
  // NOTE(review): suspected bug — this message concatenates the
  // LOG_LIKELIHOOD_THRESHOLD_DEFAULT *value*; it probably should name
  // LOG_LIKELIHOOD_THRESHOLD_PARAM instead. Confirm before changing the string.
  display("WARNING: the training parameter: " + OLD_LL_THRESHOLD_PARAM
      + " has been deprecated. Please use " + LOG_LIKELIHOOD_THRESHOLD_DEFAULT + " instead");
  // Only copy the old value over when the new parameter was not set explicitly.
  if (trainingParameters.getDoubleParameter(LOG_LIKELIHOOD_THRESHOLD_PARAM, -1.) < 0. ) {
    trainingParameters.put(LOG_LIKELIHOOD_THRESHOLD_PARAM,
        trainingParameters.getDoubleParameter(OLD_LL_THRESHOLD_PARAM,
            LOG_LIKELIHOOD_THRESHOLD_DEFAULT));
    llThreshold = trainingParameters.getDoubleParameter(LOG_LIKELIHOOD_THRESHOLD_PARAM,
        LOG_LIKELIHOOD_THRESHOLD_DEFAULT);
    useSimpleSmoothing = trainingParameters.getBooleanParameter(SMOOTHING_PARAM, SMOOTHING_DEFAULT);
    if (useSimpleSmoothing) {
      _smoothingObservation = trainingParameters.getDoubleParameter(SMOOTHING_OBSERVATION_PARAM,
          SMOOTHING_OBSERVATION);
      // NOTE(review): suspected bug — the result of this read is discarded;
      // it likely should be assigned to useGaussianSmoothing, which is tested
      // on the next line but never set here. Confirm against the full file.
      trainingParameters.getBooleanParameter(GAUSSIAN_SMOOTHING_PARAM, GAUSSIAN_SMOOTHING_DEFAULT);
      if (useGaussianSmoothing) {
        sigma = trainingParameters.getDoubleParameter(
            GAUSSIAN_SMOOTHING_SIGMA_PARAM, GAUSSIAN_SMOOTHING_SIGMA_DEFAULT);
// NOTE(review): fragment — the enclosing test method's signature and end are
// outside this chunk; tokens below are unchanged, only comments added.
ObjectStream<Event> eventStream = createEventStream();

// Start from defaults and configure a QN training run.
TrainingParameters parameters = TrainingParameters.defaultParams();
parameters.put(TrainingParameters.ITERATIONS_PARAM, 10);
parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
    AbstractEventTrainer.DATA_INDEXER_ONE_PASS_VALUE);
parameters.put(AbstractEventTrainer.CUTOFF_PARAM, 1);
parameters.put(AbstractDataIndexer.SORT_PARAM, true);
parameters.put(TrainingParameters.ALGORITHM_PARAM, QNTrainer.MAXENT_QN_VALUE);
// NOTE(review): these two puts overwrite the ONE_PASS/cutoff=1 settings above —
// presumably code elsewhere in the full method runs between the two groups; confirm.
parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
    AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE);
parameters.put(AbstractEventTrainer.CUTOFF_PARAM, 2);
@Test public void testPutGet() { TrainingParameters tp = build("k1=v1,int.k2=123,str.k2=v3,str.k3=v4,boolean.k4=false,double.k5=123.45,k21=234.5"); Assert.assertEquals("v1", tp.getStringParameter("k1", "def")); Assert.assertEquals("def", tp.getStringParameter("k2", "def")); Assert.assertEquals("v3", tp.getStringParameter("str", "k2", "def")); Assert.assertEquals("def", tp.getStringParameter("str", "k4", "def")); Assert.assertEquals(-100, tp.getIntParameter("k11", -100)); tp.put("k11", 234); Assert.assertEquals(234, tp.getIntParameter("k11", -100)); Assert.assertEquals(123, tp.getIntParameter("int", "k2", -100)); Assert.assertEquals(-100, tp.getIntParameter("int", "k4", -100)); Assert.assertEquals(234.5, tp.getDoubleParameter("k21", -100), 0.001); tp.put("k21", 345.6); Assert.assertEquals(345.6, tp.getDoubleParameter("k21", -100), 0.001); // should be changed tp.putIfAbsent("k21", 456.7); Assert.assertEquals(345.6, tp.getDoubleParameter("k21", -100), 0.001); // should be unchanged Assert.assertEquals(123.45, tp.getDoubleParameter("double", "k5", -100), 0.001); Assert.assertEquals(true, tp.getBooleanParameter("k31", true)); tp.put("k31", false); Assert.assertEquals(false, tp.getBooleanParameter("k31", true)); Assert.assertEquals(false, tp.getBooleanParameter("boolean", "k4", true)); }
// NOTE(review): fragment — this chunk is mid-method and contains lines that are
// themselves truncated argument lists from multi-line calls in the full file;
// tokens below are unchanged, only comments added.
// Train the build model from the "build" namespace parameters and record its report.
Map<String, String> buildReportMap = new HashMap<>();
EventTrainer buildTrainer =
    TrainerFactory.getEventTrainer(mlParams.getParameters("build"), buildReportMap);
MaxentModel buildModel = buildTrainer.train(bes);
mergeReportIntoManifest(manifestInfoEntries, buildReportMap, "build");
// Default the tagger beam size only when the caller did not set one.
TrainingParameters posTaggerParams = mlParams.getParameters("tagger");
if (!posTaggerParams.getObjectSettings().containsKey(BeamSearch.BEAM_SIZE_PARAMETER)) {
  mlParams.put("tagger", BeamSearch.BEAM_SIZE_PARAMETER, 10);
// NOTE(review): the next two lines are trailing-argument fragments of calls whose
// opening lines fall outside this chunk (tagger and chunker training).
mlParams.getParameters("tagger"), new POSTaggerFactory());
new ChunkSampleStream(parseSamples), mlParams.getParameters("chunker"), new ParserChunkerFactory());
// Train the check model from the "check" namespace parameters and record its report.
Map<String, String> checkReportMap = new HashMap<>();
EventTrainer checkTrainer =
    TrainerFactory.getEventTrainer(mlParams.getParameters("check"), checkReportMap);
MaxentModel checkModel = checkTrainer.train(kes);
mergeReportIntoManifest(manifestInfoEntries, checkReportMap, "check");
@Test public void testDefault() { TrainingParameters tr = TrainingParameters.defaultParams(); Assert.assertEquals(4, tr.getSettings().size()); Assert.assertEquals("MAXENT", tr.algorithm()); Assert.assertEquals(EventTrainer.EVENT_VALUE, tr.getStringParameter(TrainingParameters.TRAINER_TYPE_PARAM, "v11")); // use different defaults Assert.assertEquals(100, tr.getIntParameter(TrainingParameters.ITERATIONS_PARAM, 200)); // use different defaults Assert.assertEquals(5, tr.getIntParameter(TrainingParameters.CUTOFF_PARAM, 200)); // use different defaults }
public DataIndexer getDataIndexer(ObjectStream<Event> events) throws IOException { trainingParameters.put(AbstractDataIndexer.SORT_PARAM, isSortAndMerge()); // If the cutoff was set, don't overwrite the value. if (trainingParameters.getIntParameter(CUTOFF_PARAM, -1) == -1) { trainingParameters.put(CUTOFF_PARAM, 5); } DataIndexer indexer = DataIndexerFactory.getDataIndexer(trainingParameters, reportMap); indexer.index(events); return indexer; }
/**
 * Trains an English sentence model over the shared sample stream
 * using default training parameters.
 *
 * @param factory the sentence detector factory to use
 * @return the trained model
 * @throws IOException if reading the sample stream fails
 */
private static SentenceModel train(SentenceDetectorFactory factory) throws IOException {
  TrainingParameters defaults = TrainingParameters.defaultParams();
  return SentenceDetectorME.train("eng", createSampleStream(), factory, defaults);
}
@Test
public void testGetParameters() {
  TrainingParameters tp = build("k1=v1,n1.k2=v2,n2.k3=v3,n1.k4=v4");

  // Un-prefixed keys live in the root (null) namespace.
  TrainingParameters rootOnly = build("k1=v1");
  assertEquals(rootOnly, tp.getParameters(null));

  // Namespaced keys are grouped and returned with the prefix stripped.
  assertEquals(build("k2=v2,k4=v4"), tp.getParameters("n1"));
  assertEquals(build("k3=v3"), tp.getParameters("n2"));

  // An unknown namespace yields an empty parameter set, not null.
  Assert.assertTrue(tp.getParameters("n3").getSettings().isEmpty());
}
@Test
public void testConstructors() throws Exception {
  // ASCII-only content, so the default-charset getBytes() is deterministic here.
  String propsText = "key1=val1\nkey2=val2\nkey3=val3\n";

  TrainingParameters fromMap = new TrainingParameters(build("key1=val1,key2=val2,key3=val3"));
  TrainingParameters fromStream =
      new TrainingParameters(new ByteArrayInputStream(propsText.getBytes()));
  TrainingParameters copied = new TrainingParameters(fromStream);

  // All three construction paths must produce equal parameter sets.
  assertEquals(fromMap, fromStream);
  assertEquals(fromStream, copied);
}
/**
 * Retrieves a {@code boolean} parameter from the root (un-namespaced) settings.
 *
 * @param key the name of the parameter to look up
 * @param defaultValue the value to return when {@code key} is absent
 * @return the parameter value, or {@code defaultValue} if the key is not set
 */
public boolean getBooleanParameter(String key, boolean defaultValue) {
  // Delegates to the namespace-aware overload; null selects the root namespace.
  return getBooleanParameter(null, key, defaultValue);
}
public AbstractModel doTrain(DataIndexer indexer) throws IOException { int iterations = getIterations(); int cutoff = getCutoff(); AbstractModel model; boolean useAverage = trainingParameters.getBooleanParameter("UseAverage", true); boolean useSkippedAveraging = trainingParameters.getBooleanParameter("UseSkippedAveraging", false); // overwrite otherwise it might not work if (useSkippedAveraging) useAverage = true; double stepSizeDecrease = trainingParameters.getDoubleParameter("StepSizeDecrease", 0); double tolerance = trainingParameters.getDoubleParameter("Tolerance", PerceptronTrainer.TOLERANCE_DEFAULT); this.setSkippedAveraging(useSkippedAveraging); if (stepSizeDecrease > 0) this.setStepSizeDecrease(stepSizeDecrease); this.setTolerance(tolerance); model = this.trainModel(iterations, indexer, cutoff, useAverage); return model; }
/**
 * Retrieves an {@code int} parameter from the root (un-namespaced) settings.
 *
 * @param key the name of the parameter to look up
 * @param defaultValue the value to return when {@code key} is absent
 * @return the parameter value, or {@code defaultValue} if the key is not set
 */
public int getIntParameter(String key, int defaultValue) {
  // Delegates to the namespace-aware overload; null selects the root namespace.
  return getIntParameter(null, key, defaultValue);
}
/**
 * Indexes the event stream in a single pass: counts predicates, applies the
 * cutoff, then sorts and merges the resulting comparable events.
 *
 * @param eventStream the events to index
 * @throws IOException if reading the stream fails
 */
@Override
public void index(ObjectStream<Event> eventStream) throws IOException {
  // Resolve indexing options; the class defaults apply when a key is unset.
  final int cutoff = trainingParameters.getIntParameter(CUTOFF_PARAM, CUTOFF_DEFAULT);
  final boolean sort = trainingParameters.getBooleanParameter(SORT_PARAM, SORT_DEFAULT);

  long startMillis = System.currentTimeMillis();
  display("Indexing events with OnePass using cutoff of " + cutoff + "\n\n");

  display("\tComputing event counts... ");
  Map<String, Integer> predicateCounts = new HashMap<>();
  List<Event> retained = computeEventCounts(eventStream, predicateCounts, cutoff);
  display("done. " + retained.size() + " events\n");

  display("\tIndexing... ");
  List<ComparableEvent> comparableEvents =
      index(ObjectStreamUtils.createObjectStream(retained), predicateCounts);
  display("done.\n");

  display("Sorting and merging events... ");
  sortAndMerge(comparableEvents, sort);
  display(String.format("Done indexing in %.2f s.\n",
      (System.currentTimeMillis() - startMillis) / 1000d));
}
/**
 * Retrieves a {@code double} parameter from the root (un-namespaced) settings.
 *
 * @param key the name of the parameter to look up
 * @param defaultValue the value to return when {@code key} is absent
 * @return the parameter value, or {@code defaultValue} if the key is not set
 */
public double getDoubleParameter(String key, double defaultValue) {
  // Delegates to the namespace-aware overload; null selects the root namespace.
  return getDoubleParameter(null, key, defaultValue);
}
// NOTE(review): fragment — every line here is a truncated piece of a multi-line
// statement whose start or end lies outside this chunk; tokens are unchanged,
// only comments added. Reconstruct against the full file before editing.
// Trailing arguments of tagger / chunker training calls:
parseSamples), mlParams.getParameters("tagger"), new POSTaggerFactory());
parseSamples), mlParams.getParameters("chunker"), new ParserChunkerFactory());
// Build-model training and report merge (call heads outside this chunk):
mlParams.getParameters("build"), buildReportMap);
MaxentModel buildModel = buildTrainer.train(bes);
opennlp.tools.parser.chunking.Parser.mergeReportIntoManifest(
// Check-model training and report merge:
mlParams.getParameters("check"), checkReportMap);
MaxentModel checkModel = checkTrainer.train(kes);
opennlp.tools.parser.chunking.Parser.mergeReportIntoManifest(
// Attach-model training — this group is complete apart from the final merge call:
Map<String, String> attachReportMap = new HashMap<>();
EventTrainer attachTrainer = TrainerFactory.getEventTrainer(
    mlParams.getParameters("attach"), attachReportMap);
MaxentModel attachModel = attachTrainer.train(attachEvents);
opennlp.tools.parser.chunking.Parser.mergeReportIntoManifest(