/**
 * Build the row-index mask for every division.
 * <p>
 * Row indices are handed out consecutively across the divisions in list order:
 * the first division receives 0..count-1, the next continues from there, and so on.
 *
 * @param dataDivisionList The divisions, whose counts must already be set.
 */
private void generateMasks(List<DataDivision> dataDivisionList) {
    int nextRow = 0;
    for (DataDivision div : dataDivisionList) {
        final int rows = div.getCount();
        div.allocateMask(rows);
        int slot = 0;
        while (slot < rows) {
            div.getMask()[slot] = nextRow;
            nextRow++;
            slot++;
        }
    }
}
/**
 * Split {@code dataset} into the requested divisions.
 * <p>
 * The steps are:
 * <ol>
 * <li>size each division;</li>
 * <li>assign each division its row indices;</li>
 * <li>optionally shuffle rows across divisions, when this splitter was created with shuffling on;</li>
 * <li>attach a dataset view to each division.</li>
 * </ol>
 *
 * @param dataDivisionList The divisions to fill.
 * @param dataset The dataset being split.
 * @param inputCount The input count passed to each divided dataset.
 * @param idealCount The ideal count passed to each divided dataset.
 */
public void perform(List<DataDivision> dataDivisionList,
        VersatileMLDataSet dataset, int inputCount, int idealCount) {
    final int rowCount = dataset.getData().length;

    generateCounts(dataDivisionList, rowCount);
    generateMasks(dataDivisionList);

    if (this.shuffle) {
        performShuffle(dataDivisionList, rowCount);
    }

    createDividedDatasets(dataDivisionList, dataset, inputCount, idealCount);
}
/** * Generate the counts for all divisions, give remaining items to final division. * @param dataDivisionList The division list. * @param totalCount The total count. */ private void generateCounts(List<DataDivision> dataDivisionList, int totalCount) { // First pass at division. int countSofar = 0; for (DataDivision division : dataDivisionList) { int count = (int) (division.getPercent() * totalCount); division.setCount(count); countSofar += count; } // Adjust any remaining count int remaining = totalCount - countSofar; while (remaining-- > 0) { int idx = this.rnd.nextInt(dataDivisionList.size()); DataDivision div = dataDivisionList.get(idx); div.setCount(div.getCount() + 1); } }
/** * Specify a validation set to hold back. * @param validationPercent The percent to use for validation. * @param shuffle True to shuffle. * @param seed The seed for random generation. */ public void holdBackValidation(double validationPercent, boolean shuffle, int seed) { List<DataDivision> dataDivisionList = new ArrayList<DataDivision>(); dataDivisionList.add(new DataDivision(1.0 - validationPercent));// Training dataDivisionList.add(new DataDivision(validationPercent));// Validation this.dataset.divide(dataDivisionList, shuffle, new MersenneTwisterGenerateRandom(seed)); this.trainingDataset = dataDivisionList.get(0).getDataset(); this.validationDataset = dataDivisionList.get(1).getDataset(); }
/**
 * Split this dataset into the given divisions, optionally shuffling rows first.
 * <p>
 * The data must already be present ({@code getData()} non-null). Otherwise an
 * {@link EncogError} is thrown.
 *
 * @param dataDivisionList The desired divisions; each is filled with its own dataset.
 * @param shuffle True to shuffle rows before assigning them.
 * @param rnd Random number generator used for shuffling and remainder assignment.
 */
public void divide(List<DataDivision> dataDivisionList, boolean shuffle,
        GenerateRandom rnd) {
    final boolean notReady = getData() == null;
    if (notReady) {
        throw new EncogError(
                "Can't divide, data has not yet been generated/normalized.");
    }

    new PerformDataDivision(shuffle, rnd).perform(dataDivisionList, this,
            getCalculatedInputSize(), getCalculatedIdealSize());
}
/**
 * Attach a {@code MatrixMLDataSet} view to every division.
 * <p>
 * Each view is built over the parent's data, restricted by the division's mask. It
 * inherits the parent's lag and lead window sizes.
 *
 * @param dataDivisionList The divisions, whose masks must already be built.
 * @param parentDataset The dataset being divided.
 * @param inputCount The input count for each view.
 * @param idealCount The ideal count for each view.
 */
private void createDividedDatasets(List<DataDivision> dataDivisionList,
        VersatileMLDataSet parentDataset, int inputCount, int idealCount) {
    for (DataDivision div : dataDivisionList) {
        final MatrixMLDataSet subset = new MatrixMLDataSet(
                parentDataset.getData(), inputCount, idealCount, div.getMask());
        subset.setLagWindowSize(parentDataset.getLagWindowSize());
        subset.setLeadWindowSize(parentDataset.getLeadWindowSize());
        div.setDataset(subset);
    }
}
/**
 * Shuffle row positions across all divisions with a Fisher-Yates shuffle.
 * See http://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
 * <p>
 * Walking from the last position down to 1, each position is swapped with a position
 * chosen uniformly from {@code [0, position]}, using {@code this.rnd}. The swap itself
 * is done by {@code virtualSwap}, which is defined elsewhere in this class.
 *
 * @param dataDivisionList The division list whose positions are permuted.
 * @param totalCount The number of positions taking part in the shuffle.
 */
private void performShuffle(List<DataDivision> dataDivisionList, int totalCount) {
    int upper = totalCount - 1;
    while (upper > 0) {
        final int pick = this.rnd.nextInt(upper + 1);
        virtualSwap(dataDivisionList, upper, pick);
        upper--;
    }
}