-
Notifications
You must be signed in to change notification settings - Fork 82
Makes sure we don't pull the whole corpus into memory when training #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
fe6d329
8e81df1
e5f7d44
94db87d
61e7fc7
464abcc
8cf84b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| package com.medallia.word2vec.neuralnetwork; | ||
|
|
||
| import com.google.common.collect.Iterables; | ||
| import com.google.common.collect.Iterators; | ||
| import com.google.common.collect.Lists; | ||
| import com.google.common.collect.Multiset; | ||
| import com.google.common.util.concurrent.Futures; | ||
| import com.google.common.util.concurrent.ListenableFuture; | ||
|
|
@@ -15,8 +17,11 @@ | |
| import java.util.Iterator; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.NoSuchElementException; | ||
| import java.util.concurrent.ArrayBlockingQueue; | ||
| import java.util.concurrent.ExecutionException; | ||
| import java.util.concurrent.Executors; | ||
| import java.util.concurrent.ThreadPoolExecutor; | ||
| import java.util.concurrent.TimeUnit; | ||
| import java.util.concurrent.atomic.AtomicInteger; | ||
|
|
||
| /** Parent class for training word2vec's neural network */ | ||
|
|
@@ -51,7 +56,7 @@ public abstract class NeuralNetworkTrainer { | |
| /** | ||
| * In the C version, this includes the </s> token that replaces a newline character | ||
| */ | ||
| int numTrainedTokens; | ||
| long numTrainedTokens; | ||
|
|
||
| /* The following includes shared state that is updated per worker thread */ | ||
|
|
||
|
|
@@ -151,28 +156,54 @@ public interface NeuralNetworkModel { | |
| } | ||
|
|
||
| /** @return Trained NN model */ | ||
| public NeuralNetworkModel train(Iterable<List<String>> sentences) throws InterruptedException { | ||
| ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads)); | ||
|
|
||
| public NeuralNetworkModel train(final Iterable<List<String>> sentences) throws InterruptedException { | ||
| final ListeningExecutorService ex = | ||
| MoreExecutors.listeningDecorator( | ||
| new ThreadPoolExecutor(config.numThreads, config.numThreads, | ||
| 0L, TimeUnit.MILLISECONDS, | ||
| new ArrayBlockingQueue<Runnable>(8), | ||
| new ThreadPoolExecutor.CallerRunsPolicy())); | ||
|
|
||
| int numSentences = Iterables.size(sentences); | ||
| numTrainedTokens += numSentences; | ||
|
|
||
| // Partition the sentences evenly amongst the threads | ||
| Iterable<List<List<String>>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1); | ||
|
|
||
|
|
||
| // Partition the sentences into batches | ||
| final Iterable<List<List<String>>> batched = new Iterable<List<List<String>>>() { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't remember why I didn't just use |
||
| @Override public Iterator<List<List<String>>> iterator() { | ||
| return new Iterator<List<List<String>>>() { | ||
| private final Iterator<List<String>> inner = sentences.iterator(); | ||
|
|
||
| @Override | ||
| public boolean hasNext() { | ||
| return inner.hasNext(); | ||
| } | ||
|
|
||
| @Override | ||
| public List<List<String>> next() { | ||
| if(!hasNext()) | ||
| throw new NoSuchElementException(); | ||
|
|
||
| return Lists.newArrayList(Iterators.limit(inner, 1024)); | ||
| } | ||
|
|
||
| @Override | ||
| public void remove() { | ||
| throw new UnsupportedOperationException(); | ||
| } | ||
| }; | ||
| } | ||
| }; | ||
|
|
||
| try { | ||
| listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0); | ||
| for (int iter = config.iterations; iter > 0; iter--) { | ||
| List<CallableVoid> tasks = new ArrayList<>(); | ||
| List<ListenableFuture<?>> futures = new ArrayList<>(64); | ||
| int i = 0; | ||
| for (final List<List<String>> batch : partitioned) { | ||
| tasks.add(createWorker(i, iter, batch)); | ||
| for (final List<List<String>> batch : batched) { | ||
| futures.add(ex.submit(createWorker(i, iter, batch))); | ||
| i++; | ||
| } | ||
|
||
|
|
||
| List<ListenableFuture<?>> futures = new ArrayList<>(tasks.size()); | ||
| for (CallableVoid task : tasks) | ||
| futures.add(ex.submit(task)); | ||
| try { | ||
| Futures.allAsList(futures).get(); | ||
| } catch (ExecutionException e) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat trick, but let's leave a comment about what we're trying to accomplish here. If I understand correctly, the overall idea is to have executor.submit block if there are no available threads, to avoid materializing the sentences in memory before they are needed. The ArrayBlockingQueue and CallerRunsPolicy is one way to accomplish this.
Any reason why the blocking queue starts with size 8 instead of config.numThreads?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct. I'll add a comment.
The queue size could be
config.numThreads, but it's not really connected to the number of processors. It's just connected to the amount of overhead there is in creating these threads. In principle, a queue size of 1 should do, but I tried that and it was slower. I'm worried that if I set it to the number of processors, I'll run out of memory on a machine with lots of cores.

There was a problem hiding this comment.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, that's not correct. The queue size matters when the main thread is running a task due to the
CallerRunsPolicy. So it is connected to the number of processors. I changed it.