@@ -1,6 +1,8 @@
package com.medallia.word2vec.neuralnetwork;

import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
@@ -15,8 +17,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/** Parent class for training word2vec's neural network */
@@ -51,7 +56,7 @@ public abstract class NeuralNetworkTrainer {
/**
* In the C version, this includes the </s> token that replaces a newline character
*/
int numTrainedTokens;
long numTrainedTokens;

/* The following includes shared state that is updated per worker thread */

@@ -151,28 +156,54 @@ public interface NeuralNetworkModel {
}

/** @return Trained NN model */
public NeuralNetworkModel train(Iterable<List<String>> sentences) throws InterruptedException {
ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads));

public NeuralNetworkModel train(final Iterable<List<String>> sentences) throws InterruptedException {
final ListeningExecutorService ex =
MoreExecutors.listeningDecorator(
new ThreadPoolExecutor(config.numThreads, config.numThreads,
Contributor commented:
Neat trick, but let's leave a comment explaining what we're trying to accomplish here. If I understand correctly, the overall idea is to have executor.submit block when there are no available threads, to avoid materializing the sentences in memory before they are needed. The ArrayBlockingQueue and CallerRunsPolicy combination is one way to accomplish this.

Any reason why the blocking queue starts with size 8 instead of config.numThreads?

Contributor Author replied:
That's correct. I'll add a comment.

The queue size could be config.numThreads, but it's not really connected to the number of processors. It's just connected to the amount of overhead there is in creating these threads. In principle, a queue size of 1 should do, but I tried that and it was slower. I'm worried that if I set it to the number of processors, I'll run out of memory on a machine with lots of cores.

Contributor Author replied:
Actually, that's not correct. The queue size matters when the main thread is running a task due to the CallerRunsPolicy. So it is connected to the number of processors. I changed it.
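The backpressure behavior discussed in this thread can be exercised in isolation. The sketch below is a standalone demo, not the project's code: the thread count, task count, and sleep duration are arbitrary. With a bounded ArrayBlockingQueue and CallerRunsPolicy, once the queue fills up, execute() runs the task on the submitting thread, which throttles how fast new work (and its captured data) can be produced.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class BackpressureExecutorDemo {
    public static void main(String[] args) throws InterruptedException {
        int numThreads = 2;
        // Bounded queue + CallerRunsPolicy: when the queue is full, execute()
        // runs the task on the caller's thread instead of rejecting it.
        ThreadPoolExecutor ex = new ThreadPoolExecutor(
                numThreads, numThreads,
                0L, TimeUnit.MILLISECONDS,
                new ArrayBlockingQueue<Runnable>(numThreads),
                new ThreadPoolExecutor.CallerRunsPolicy());

        final AtomicInteger ranOnCaller = new AtomicInteger();
        final Thread caller = Thread.currentThread();
        for (int i = 0; i < 50; i++) {
            ex.execute(() -> {
                if (Thread.currentThread() == caller) {
                    ranOnCaller.incrementAndGet();   // task ran on the submitter
                }
                try {
                    Thread.sleep(5);                 // simulate a slow worker
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
        }
        ex.shutdown();
        ex.awaitTermination(1, TimeUnit.MINUTES);
        // With 2 workers, a 2-slot queue, and 50 slow tasks, the submitting
        // thread is forced to run some tasks itself.
        System.out.println("caller ran " + ranOnCaller.get() + " tasks");
    }
}
```

Note this also explains the author's last remark: because the caller itself runs tasks under CallerRunsPolicy, the queue depth effectively adds one more "worker", so sizing it relative to the thread count is reasonable.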

0L, TimeUnit.MILLISECONDS,
new ArrayBlockingQueue<Runnable>(8),
new ThreadPoolExecutor.CallerRunsPolicy()));

int numSentences = Iterables.size(sentences);
numTrainedTokens += numSentences;

// Partition the sentences evenly amongst the threads
Iterable<List<List<String>>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1);


// Partition the sentences into batches
final Iterable<List<List<String>>> batched = new Iterable<List<List<String>>>() {
Contributor Author commented:
I don't remember why I didn't just use Iterables.partition() here. I'm trying to go back to that to see what the problem is.

@Override public Iterator<List<List<String>>> iterator() {
return new Iterator<List<List<String>>>() {
private final Iterator<List<String>> inner = sentences.iterator();

@Override
public boolean hasNext() {
return inner.hasNext();
}

@Override
public List<List<String>> next() {
if(!hasNext())
throw new NoSuchElementException();

return Lists.newArrayList(Iterators.limit(inner, 1024));
}

@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}
};
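The hand-rolled batcher in the diff can be tested on its own against plain JDK collections. The following is an illustrative standalone sketch (the batches() helper name and the batch size of 4 are hypothetical, not part of the patch); like the patch's iterator, it yields fixed-size batches lazily, with a possibly smaller final batch.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

public class LazyBatcher {
    /** Lazily groups an iterator into fixed-size batches (last batch may be smaller). */
    static <T> Iterator<List<T>> batches(final Iterator<T> inner, final int size) {
        return new Iterator<List<T>>() {
            @Override public boolean hasNext() {
                return inner.hasNext();
            }

            @Override public List<T> next() {
                if (!hasNext())
                    throw new NoSuchElementException();
                // Pull at most `size` elements from the underlying iterator.
                List<T> batch = new ArrayList<>(size);
                while (inner.hasNext() && batch.size() < size)
                    batch.add(inner.next());
                return batch;
            }

            @Override public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) data.add(i);

        Iterator<List<Integer>> it = batches(data.iterator(), 4);
        int count = 0;
        List<Integer> last = null;
        while (it.hasNext()) {
            last = it.next();
            count++;
        }
        System.out.println(count + " batches, last = " + last);
        // → 3 batches, last = [8, 9]
    }
}
```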

try {
listener.update(Stage.TRAIN_NEURAL_NETWORK, 0.0);
for (int iter = config.iterations; iter > 0; iter--) {
List<CallableVoid> tasks = new ArrayList<>();
List<ListenableFuture<?>> futures = new ArrayList<>(64);
int i = 0;
for (final List<List<String>> batch : partitioned) {
tasks.add(createWorker(i, iter, batch));
for (final List<List<String>> batch : batched) {
futures.add(ex.submit(createWorker(i, iter, batch)));
i++;
}
Contributor Author commented:
This for loop would pull the entire training data into memory, because every worker contains a batch, and all workers are instantiated before the first one starts working.
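A toy illustration of that point, using counters as a stand-in for batch memory (this is a made-up demo, not the project's code): building every task up front keeps all batches live at once, while producing and processing one batch at a time keeps only one live.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class EagerVsLazyDemo {
    public static void main(String[] args) {
        AtomicInteger live = new AtomicInteger();   // batches currently in memory
        AtomicInteger peak = new AtomicInteger();   // worst case observed

        // Eager: instantiate all 5 "workers" (each holding a batch) first.
        Runnable[] tasks = new Runnable[5];
        for (int i = 0; i < 5; i++) {
            live.incrementAndGet();                          // batch materialized
            peak.accumulateAndGet(live.get(), Math::max);
            tasks[i] = live::decrementAndGet;                // batch freed when run
        }
        for (Runnable t : tasks) t.run();
        int eagerPeak = peak.get();

        // Lazy: materialize and process one batch at a time.
        live.set(0);
        peak.set(0);
        for (int i = 0; i < 5; i++) {
            live.incrementAndGet();
            peak.accumulateAndGet(live.get(), Math::max);
            live.decrementAndGet();
        }
        int lazyPeak = peak.get();

        System.out.println("eager peak = " + eagerPeak + ", lazy peak = " + lazyPeak);
        // → eager peak = 5, lazy peak = 1
    }
}
```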


List<ListenableFuture<?>> futures = new ArrayList<>(tasks.size());
for (CallableVoid task : tasks)
futures.add(ex.submit(task));
try {
Futures.allAsList(futures).get();
} catch (ExecutionException e) {