/**
 * Makes {@code outputColumn} the only output column and every other
 * non-ignored source column an input.
 *
 * <p>Any existing input/output assignment is cleared first. The output column
 * is matched by reference identity ({@code ==}). Other columns whose data type
 * is {@code ColumnType.ignore} are left unassigned.
 *
 * @param outputColumn The column to define as the output.
 */
public void defineSingleOutputOthersInput(ColumnDefinition outputColumn) {
    this.helper.clearInputOutput();
    for (ColumnDefinition column : this.helper.getSourceColumns()) {
        if (column == outputColumn) {
            defineOutput(column);
            continue;
        }
        if (column.getDataType() == ColumnType.ignore) {
            // Ignored columns are neither inputs nor outputs.
            continue;
        }
        defineInput(column);
    }
}
/**
 * {@inheritDoc}
 *
 * <p>Delegates to the dataset's normalization helper, which reports how many
 * values the normalized output columns occupy.
 */
@Override
public int determineOutputCount(VersatileMLDataSet dataset) {
    int normalizedOutputs = dataset.getNormHelper().calculateNormalizedOutputCount();
    return normalizedOutputs;
}
}
/**
 * Splits the generated data into the requested divisions, optionally
 * shuffling it first.
 *
 * @param dataDivisionList The divisions to populate.
 * @param shuffle True to shuffle the data before it is divided.
 * @param rnd Random number generator handed to the division step, often
 *        created with a specific seed.
 * @throws EncogError if the data has not been generated/normalized yet.
 */
public void divide(List<DataDivision> dataDivisionList, boolean shuffle, GenerateRandom rnd) {
    boolean dataMissing = getData() == null;
    if (dataMissing) {
        throw new EncogError(
                "Can't divide, data has not yet been generated/normalized.");
    }
    new PerformDataDivision(shuffle, rnd).perform(
            dataDivisionList, this, getCalculatedInputSize(), getCalculatedIdealSize());
}
@Override protected void doTrain() { VersatileMLDataSet data = new VersatileMLDataSet(new VersatileDataSource() { int idx = 0; offsets.entrySet().stream().sorted(Comparator.comparingInt(Map.Entry::getValue)).forEach(e -> { String k = e.getKey(); ColumnDefinition col = data.defineSourceColumn(k, offsets.get(k), typeOf(types.get(k))); // todo has bug, doesn't work like that, cols have to be in index order if (k.equals(output)) { data.defineOutput(col); } else { data.defineInput(col); data.analyze(); data.normalize();
setCalculatedIdealSize(normalizedOutputColumns); setCalculatedInputSize(normalizedInputColumns); setData(new double[this.analyzedRows][normalizedColumns]); int column = 0; for (ColumnDefinition colDef : this.helper.getInputColumns()) { int index = findIndex(colDef); String value = line[index]; getData()[row], true, value); int index = findIndex(colDef); String value = line[index]; getData()[row], false, value);
/** * Specify a validation set to hold back. * @param validationPercent The percent to use for validation. * @param shuffle True to shuffle. * @param seed The seed for random generation. */ public void holdBackValidation(double validationPercent, boolean shuffle, int seed) { List<DataDivision> dataDivisionList = new ArrayList<DataDivision>(); dataDivisionList.add(new DataDivision(1.0 - validationPercent));// Training dataDivisionList.add(new DataDivision(validationPercent));// Validation this.dataset.divide(dataDivisionList, shuffle, new MersenneTwisterGenerateRandom(seed)); this.trainingDataset = dataDivisionList.get(0).getDataset(); this.validationDataset = dataDivisionList.get(1).getDataset(); }
for (int i = 0; i < this.helper.getSourceColumns().size(); i++) { ColumnDefinition colDef = this.helper.getSourceColumns().get(i); int index = findIndex(colDef); String value = line[index]; colDef.analyze(value);
/**
 * {@inheritDoc}
 *
 * <p>Delegates to the dataset's normalization helper, which reports how many
 * values the normalized output columns occupy.
 */
@Override
public int determineOutputCount(VersatileMLDataSet dataset) {
    int normalizedOutputs = dataset.getNormHelper().calculateNormalizedOutputCount();
    return normalizedOutputs;
}
}
/**
 * Makes every column in {@code outputColumns} an output and every other
 * non-ignored source column an input.
 *
 * <p>Any existing input/output assignment is cleared first. Membership in
 * {@code outputColumns} is tested by reference identity ({@code ==}). Other
 * columns whose data type is {@code ColumnType.ignore} are left unassigned.
 *
 * @param outputColumns The columns to define as outputs.
 */
public void defineMultipleOutputsOthersInput(ColumnDefinition[] outputColumns) {
    this.helper.clearInputOutput();
    for (ColumnDefinition candidate : this.helper.getSourceColumns()) {
        // Scan the requested outputs, stopping as soon as a match is found.
        boolean matched = false;
        for (int i = 0; i < outputColumns.length && !matched; i++) {
            matched = outputColumns[i] == candidate;
        }

        if (matched) {
            defineOutput(candidate);
        } else if (candidate.getDataType() != ColumnType.ignore) {
            defineInput(candidate);
        }
    }
}
/**
 * {@inheritDoc}
 *
 * <p>Produces an architecture string of the form {@code ?->gaussian(c=N)->?}.
 * N is 1.5 times the combined number of input and output columns, truncated
 * to an int.
 */
@Override
public String suggestModelArchitecture(VersatileMLDataSet dataset) {
    int totalColumns = dataset.getNormHelper().getInputColumns().size()
            + dataset.getNormHelper().getOutputColumns().size();
    int hiddenCount = (int) (totalColumns * 1.5);
    return "?->gaussian(c=" + hiddenCount + ")->?";
}
/**
 * {@inheritDoc}
 *
 * <p>Delegates to the dataset's normalization helper, which reports how many
 * values the normalized output columns occupy.
 */
@Override
public int determineOutputCount(VersatileMLDataSet dataset) {
    int normalizedOutputs = dataset.getNormHelper().calculateNormalizedOutputCount();
    return normalizedOutputs;
}
}
/**
 * {@inheritDoc}
 *
 * <p>Delegates to the dataset's normalization helper, which reports how many
 * values the normalized output columns occupy.
 */
@Override
public int determineOutputCount(VersatileMLDataSet dataset) {
    int normalizedOutputs = dataset.getNormHelper().calculateNormalizedOutputCount();
    return normalizedOutputs;
}
}
/**
 * {@inheritDoc}
 *
 * <p>Produces an architecture string of the form
 * {@code ?:B->TANH->N:B->TANH->?}. N is 1.5 times the combined number of input
 * and output columns, truncated to an int.
 */
@Override
public String suggestModelArchitecture(VersatileMLDataSet dataset) {
    int totalColumns = dataset.getNormHelper().getInputColumns().size()
            + dataset.getNormHelper().getOutputColumns().size();
    int hiddenCount = (int) (totalColumns * 1.5);
    return "?:B->TANH->" + hiddenCount + ":B->TANH->?";
}
/**
 * {@inheritDoc}
 *
 * <p>Returns the number of classes recorded for the first output column. Only
 * that first output column is considered.
 */
@Override
public int determineOutputCount(VersatileMLDataSet dataset) {
    int classCount = dataset.getNormHelper().getOutputColumns().get(0).getClasses().size();
    return classCount;
}
}
/**
 * {@inheritDoc}
 *
 * <p>The SVM architecture supports exactly one output column. The string is
 * {@code ?->C->?} when that column is nominal and {@code ?->R->?} otherwise.
 * Presumably C selects classification and R selects regression.
 *
 * @throws EncogError if the dataset has more than one output column or none.
 */
@Override
public String suggestModelArchitecture(VersatileMLDataSet dataset) {
    int outputColumns = dataset.getNormHelper().getOutputColumns().size();
    if (outputColumns > 1) {
        throw new EncogError("SVM does not support multiple output columns.");
    }
    if (outputColumns == 0) {
        // Without this guard, get(0) below would fail with an unhelpful
        // IndexOutOfBoundsException instead of a descriptive error.
        throw new EncogError("SVM requires exactly one output column, but none was defined.");
    }

    ColumnType ct = dataset.getNormHelper().getOutputColumns().get(0).getDataType();
    StringBuilder result = new StringBuilder();
    result.append("?->");
    if (ct == ColumnType.nominal) {
        result.append("C");
    } else {
        result.append("R");
    }
    result.append("->?");
    return result.toString();
}
/**
 * Calculates the error of a method on a dataset.
 *
 * <p>If the model's dataset has exactly one output column and that column is
 * nominal, the classification error is returned, and {@code method} is cast
 * to {@code MLClassification}. Otherwise the regression error is returned,
 * and {@code method} is cast to {@code MLRegression}.
 *
 * @param method The method to evaluate.
 * @param data The data to evaluate it on.
 * @return The calculated error.
 */
public double calculateError(MLMethod method, MLDataSet data) {
    boolean nominalSingleOutput = false;
    if (this.dataset.getNormHelper().getOutputColumns().size() == 1) {
        ColumnType outputType =
                this.dataset.getNormHelper().getOutputColumns().get(0).getDataType();
        nominalSingleOutput = outputType == ColumnType.nominal;
    }
    return nominalSingleOutput
            ? EncogUtility.calculateClassificationError((MLClassification) method, data)
            : EncogUtility.calculateRegressionError((MLRegression) method, data);
}
/**
 * Selects the method type, using caller-supplied method arguments.
 *
 * <p>Records the method type, its arguments and its configuration. It then
 * replaces the dataset's normalization strategy with the one that
 * configuration suggests.
 *
 * @param dataset The dataset whose normalization strategy is updated.
 * @param methodType Key of the method configuration to use.
 * @param methodArgs The method arguments.
 * @param trainingType The training type. NOTE(review): not read by this
 *        method — confirm whether it is meant to be stored.
 * @param trainingArgs The training arguments. NOTE(review): likewise not
 *        read by this method.
 * @throws EncogError if no configuration is registered for {@code methodType}.
 */
public void selectMethod(VersatileMLDataSet dataset, String methodType,
        String methodArgs, String trainingType, String trainingArgs) {
    if (!this.methodConfigurations.containsKey(methodType)) {
        throw new EncogError("Don't know how to autoconfig method: "
                + methodType);
    }
    this.config = this.methodConfigurations.get(methodType);
    this.methodType = methodType;
    this.methodArgs = methodArgs;
    // Reuse the configuration just looked up instead of querying the map again.
    dataset.getNormHelper().setStrategy(
            this.config.suggestNormalizationStrategy(dataset, methodArgs));
}
/**
 * Selects the method type and lets its configuration suggest the arguments.
 *
 * <p>The selected configuration first proposes a model architecture for the
 * dataset, which becomes the method arguments. It then proposes the
 * normalization strategy installed on the dataset.
 *
 * @param dataset The dataset to configure for.
 * @param methodType Key of the method configuration to use.
 * @throws EncogError if no configuration is registered for {@code methodType}.
 */
public void selectMethod(VersatileMLDataSet dataset, String methodType) {
    if (!this.methodConfigurations.containsKey(methodType)) {
        throw new EncogError("Don't know how to autoconfig method: "
                + methodType);
    }
    this.config = this.methodConfigurations.get(methodType);
    this.methodType = methodType;

    String suggestedArgs = this.config.suggestModelArchitecture(dataset);
    this.methodArgs = suggestedArgs;
    dataset.getNormHelper().setStrategy(
            this.config.suggestNormalizationStrategy(dataset, suggestedArgs));
}
/**
 * Predicts the output for one raw input row.
 *
 * <p>The row is normalized into a freshly allocated input vector by calling
 * {@code normalizeInputVector(line, ..., false)} and is then computed by
 * {@code method}. For a {@code _float} output, the first output value is
 * returned as a boxed double. For a {@code _class} output, the first
 * denormalized string is returned.
 *
 * @param line The raw input values.
 * @return The predicted value.
 * @throws IllegalArgumentException if the output type is neither
 *         {@code _float} nor {@code _class}. A missing type entry (null) fails
 *         in the switch with a NullPointerException.
 */
@Override
protected Object doPredict(String[] line) {
    NormalizationHelper normHelper = model.getDataset().getNormHelper();
    MLData features = normHelper.allocateInputVector();
    normHelper.normalizeInputVector(line, features.getData(), false);
    MLData prediction = method.compute(features);

    DataType outputType = types.get(this.output);
    switch (outputType) {
        case _float:
            return prediction.getData(0);
        case _class:
            return normHelper.denormalizeOutputVectorToString(prediction)[0];
        default:
            throw new IllegalArgumentException("Output type not yet supported " + outputType);
    }
}
/**
 * Attaches training diagnostics to a model result.
 *
 * <p>Adds the regression error on the training and validation datasets, the
 * method's string form, and the normalization helper's string form.
 * NOTE(review): regression error is reported even when the output is a
 * class. Confirm whether a classification error is wanted in that case.
 *
 * @param result The result to enrich.
 * @return The result returned by {@code result.withInfo(...)}.
 */
@Override
protected ML.ModelResult resultWithInfo(ML.ModelResult result) {
    double trainingError =
            EncogUtility.calculateRegressionError(method, model.getTrainingDataset());
    double validationError =
            EncogUtility.calculateRegressionError(method, model.getValidationDataset());
    String selectedMethod = method.toString();
    String normalization = model.getDataset().getNormHelper().toString();
    return result.withInfo(
            "trainingError", trainingError,
            "validationError", validationError,
            "selectedMethod", selectedMethod,
            "normalization", normalization);
}
}