diff --git a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java index a415c53..99e72e0 100644 --- a/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java +++ b/src/main/java/com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.java @@ -1,6 +1,8 @@ package com.medallia.word2vec.neuralnetwork; import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; import com.google.common.collect.Multiset; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -15,8 +17,11 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; /** Parent class for training word2vec's neural network */ @@ -51,7 +56,7 @@ public abstract class NeuralNetworkTrainer { /** * In the C version, this includes the token that replaces a newline character */ - int numTrainedTokens; + long numTrainedTokens; /* The following includes shared state that is updated per worker thread */ @@ -151,28 +156,33 @@ public interface NeuralNetworkModel { } /** @return Trained NN model */ - public NeuralNetworkModel train(Iterable> sentences) throws InterruptedException { - ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads)); - + public NeuralNetworkModel train(final Iterable> sentences) throws InterruptedException { + // Create an executor that runs as many threads as are defined in the config, and blocks if + // you're trying to run more. This is to make sure we don't read the entire corpus into + // memory. + final ListeningExecutorService ex = + MoreExecutors.listeningDecorator( + new ThreadPoolExecutor(config.numThreads, config.numThreads, + 0L, TimeUnit.MILLISECONDS, + new ArrayBlockingQueue(config.numThreads), + new ThreadPoolExecutor.CallerRunsPolicy())); + int numSentences = Iterables.size(sentences); numTrainedTokens += numSentences; - - // Partition the sentences evenly amongst the threads - Iterable>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1); - + + // Partition the sentences into batches + final Iterable>> batched = Iterables.partition(sentences, 1024); + try { listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0); for (int iter = config.iterations; iter > 0; iter--) { - List tasks = new ArrayList<>(); + List> futures = new ArrayList<>(64); int i = 0; - for (final List> batch : partitioned) { - tasks.add(createWorker(i, iter, batch)); + for (final List> batch : batched) { + futures.add(ex.submit(createWorker(i, iter, batch))); i++; } - List> futures = new ArrayList<>(tasks.size()); - for (CallableVoid task : tasks) - futures.add(ex.submit(task)); try { Futures.allAsList(futures).get(); } catch (ExecutionException e) {