DataSetIterator test = new ExistingMiniBatchDataSetIterator(new File(TEST_PATH));

ParallelWrapper pw = new ParallelWrapper.Builder<>(net)
        .prefetchBuffer(16 * Nd4j.getAffinityManager().getNumberOfDevices())
        .reportScoreAfterAveraging(true)
        .averagingFrequency(10)
        .workers(Nd4j.getAffinityManager().getNumberOfDevices())
        .build();

pw.fit(train);
train.reset();
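The ExistingMiniBatchDataSetIterator used here streams minibatches that were serialized to disk in advance, so no vectorization cost is paid during training. A minimal sketch of how those files could be produced and read back, assuming a hypothetical source iterator sourceIter and directory TRAIN_PATH, and that the default dataset-%d.bin naming template is in use:

// Hypothetical one-time export, producing the files the iterator above reads back.
// 'sourceIter' and TRAIN_PATH are illustrative names, not from the original code.
int idx = 0;
while (sourceIter.hasNext()) {
    DataSet ds = sourceIter.next();
    // File names must match the template the iterator expects (dataset-%d.bin by default)
    ds.save(new File(TRAIN_PATH, String.format("dataset-%d.bin", idx++)));
}

// Read the serialized minibatches back; this is the 'train' iterator fitted above
DataSetIterator train = new ExistingMiniBatchDataSetIterator(new File(TRAIN_PATH));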
// Inside an epoch loop; 'i' is the current epoch index
long time = System.currentTimeMillis();

ParallelWrapper wrapper = new ParallelWrapper.Builder(vgg16)
        .prefetchBuffer(24)
        .workers(Nd4j.getAffinityManager().getNumberOfDevices())
        .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
        .build();

wrapper.fit(trainIter);

time = System.currentTimeMillis() - time;
log.info("*** Completed epoch {}, time: {} ***", i, time);
// Start-of-epoch timestamp, compared against time2 below
long time1 = System.currentTimeMillis();

ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
        .prefetchBuffer(24)
        .workers(2)
        .workspaceMode(WorkspaceMode.SINGLE)
        .trainerFactory(new SymmetricTrainerContext())
        .trainingMode(ParallelWrapper.TrainingMode.CUSTOM)
        .gradientsAccumulator(new EncodedGradientsAccumulator(2, 1e-3))
        .build();

wrapper.fit(mnistTrain);

long time2 = System.currentTimeMillis();
log.info("*** Completed epoch {}, time: {} ***", i, (time2 - time1));
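After the wrapped run finishes, the underlying model can be evaluated as usual; a short sketch, assuming a mnistTest iterator exists alongside mnistTrain (the name is illustrative):

// Hypothetical evaluation pass once training completes
Evaluation eval = model.evaluate(mnistTest);
log.info(eval.stats());
mnistTest.reset();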
/**
 * Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners
 * that implement the {@link RoutingIterationListener} interface)
 *
 * @param statsStorage Stats storage router to place the results into
 * @param listeners    Listeners to set
 */
public void setListeners(StatsStorageRouter statsStorage, IterationListener... listeners) {
    setListeners(statsStorage, Arrays.asList(listeners));
}
public INDArray output(INDArray input) {
    // Depending on the inference mode, this either dispatches the input to a
    // specific worker model immediately, or queues it until a batch is assembled
    return output(new INDArray[] {input})[0];
}
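For context, this single-input overload is what callers reach through a configured ParallelInference instance; a minimal sketch, assuming a trained model (BATCHED mode is what triggers the "wait for batch" path mentioned in the comment):

// Hypothetical setup: several replicas of the model serve inference requests
ParallelInference pi = new ParallelInference.Builder(model)
        .inferenceMode(InferenceMode.BATCHED) // queue single inputs until a batch forms
        .batchLimit(32)                       // upper bound on assembled batch size
        .workers(2)                           // number of model replicas
        .build();

// 'features' is an illustrative INDArray of inputs
INDArray prediction = pi.output(features);    // dispatches through the overload above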
protected long getWorkerCounter(int workerIdx) {
    return zoo[workerIdx].getCounterValue();
}
@Override
public boolean removeAll(Collection<?> c) {
    boolean modified = false;
    for (Object o : c) {
        // Collection contract: report true only if this collection actually changed
        modified |= remove(o);
    }
    return modified;
}
@Override
public boolean addAll(Collection<? extends T> c) {
    boolean modified = false;
    for (T ds : c) {
        // Collection contract: report true only if this collection actually changed
        modified |= add(ds);
    }
    return modified;
}
@Override
public boolean containsAll(Collection<?> c) {
    for (Object o : c) {
        if (!contains(o))
            return false;
    }
    return true;
}
@Override
public boolean addAll(Collection<? extends E> c) {
    boolean modified = false;
    for (E e : c) {
        modified |= add(e);
    }
    return modified;
}
@Override
public boolean isEmpty() {
    return size() == 0;
}
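The return values above follow the java.util.Collection contract: a bulk mutator reports true only if the collection actually changed. A quick illustrative check against ArrayList (names here are only for demonstration):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CollectionContractDemo {
    public static void main(String[] args) {
        List<String> list = new ArrayList<>(Arrays.asList("a", "b"));
        // removeAll returns true only when something was actually removed
        System.out.println(list.removeAll(Arrays.asList("x", "y"))); // false: nothing removed
        System.out.println(list.removeAll(Arrays.asList("a")));      // true: list changed
    }
}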
log.info(transferLearningHelper.unfrozenGraph().summary());

ParallelWrapper wrapper = new ParallelWrapper.Builder(transferLearningHelper.unfrozenGraph())
        .prefetchBuffer(24)
        .workers(Nd4j.getAffinityManager().getNumberOfDevices())
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();

wrapper.fit(trainIter);
trainIter.reset();
log.info("Epoch #" + epoch + " complete");
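For context, the unfrozen graph trained above typically comes from freezing a pretrained network up to a named layer; a minimal sketch, assuming a pretrained ComputationGraph pretrainedGraph and an illustrative boundary layer name "fc2":

// Hypothetical construction: everything up to and including "fc2" stays frozen
TransferLearningHelper transferLearningHelper =
        new TransferLearningHelper(pretrainedGraph, "fc2");

// Optionally featurize the data once, so frozen layers aren't recomputed each
// epoch ('trainingSet' is an illustrative DataSet)
DataSet featurized = transferLearningHelper.featurize(trainingSet);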
ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
        .prefetchBuffer(24)
        .workers(4)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();

wrapper.fit(iter);
/**
 * Convenience overload: wraps a raw float array into an INDArray row vector
 * and delegates to {@link #output(INDArray)}.
 *
 * @param input input features as a flat float array
 * @return the model's output for the given input
 */
public INDArray output(float[] input) {
    return output(Nd4j.create(input));
}
// Start-of-epoch timestamp, compared against time2 below
long time1 = System.currentTimeMillis();

ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
        .prefetchBuffer(24)
        .workers(4)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();

wrapper.fit(mnistTrain);

long time2 = System.currentTimeMillis();
log.info("*** Completed epoch {}, time: {} ***", i, (time2 - time1));
// Start-of-epoch timestamp, compared against time2 below
long time1 = System.currentTimeMillis();

ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
        .prefetchBuffer(24)
        .workers(2)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();

wrapper.fit(mnistTrain);

long time2 = System.currentTimeMillis();
log.info("*** Completed epoch {}, time: {} ***", i, (time2 - time1));
ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
        .prefetchBuffer(24)
        .workers(8)
        .averagingFrequency(3)
        .reportScoreAfterAveraging(true)
        .build();

for (int i = 0; i < nTrainEpochs; i++) {
    // Fresh iterator per epoch, so no reset() is needed
    DataSetIterator trainData = getDataSetIterator(dataDirectory, 0, testStartIdx - 1, miniBatchSize);
    wrapper.fit(trainData);
}
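Once all epochs are done, the wrapper's worker threads should be released; ParallelWrapper exposes shutdown() for this:

// Release the wrapper's worker threads after the final epoch
wrapper.shutdown();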