27 changes: 17 additions & 10 deletions src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java
@@ -11,7 +11,7 @@
* Represents a word2vec model where all the vectors are normalized to unit length.
*/
public class NormalizedWord2VecModel extends Word2VecModel {
private NormalizedWord2VecModel(Iterable<String> vocab, int layerSize, final DoubleBuffer vectors) {
private NormalizedWord2VecModel(Iterable<String> vocab, int layerSize, final DoubleBuffer[] vectors) {
super(vocab, layerSize, vectors);
normalize();
}
@@ -22,7 +22,11 @@ private NormalizedWord2VecModel(Iterable<String> vocab, int layerSize, double[]
}

public static NormalizedWord2VecModel fromWord2VecModel(Word2VecModel model) {
return new NormalizedWord2VecModel(model.vocab, model.layerSize, model.vectors.duplicate());
DoubleBuffer[] newVectors = new DoubleBuffer[model.vectors.length];
for (int i = 0; i < newVectors.length; i++) {
newVectors[i] = model.vectors[i].duplicate();
}
return new NormalizedWord2VecModel(model.vocab, model.layerSize, newVectors);
}

/** @return {@link NormalizedWord2VecModel} created from a thrift representation */
@@ -36,14 +40,17 @@ public static NormalizedWord2VecModel fromBinFile(final File file) throws IOExce

/** Normalizes the vectors in this model */
private void normalize() {
for(int i = 0; i < vocab.size(); ++i) {
double len = 0;
for(int j = i * layerSize; j < (i + 1) * layerSize; ++j)
len += vectors.get(j) * vectors.get(j);
len = Math.sqrt(len);

for(int j = i * layerSize; j < (i + 1) * layerSize; ++j)
vectors.put(j, vectors.get(j) / len);
for(int i = 0; i < vectors.length; ++i) {
DoubleBuffer buffer = vectors[i];
for(int j = 0; j < Math.min(vectorsPerBuffer, buffer.limit() / layerSize); ++j) {
double len = 0;
for(int k = j * layerSize; k < (j + 1) * layerSize; ++k)
len += buffer.get(k) * buffer.get(k);
len = Math.sqrt(len);

for(int k = j * layerSize; k < (j + 1) * layerSize; ++k)
buffer.put(k, buffer.get(k) / len);
}
}
}
}
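
An aside on the new loop structure: the old normalize() indexed a single flat buffer by vocabulary position, while the new version iterates over buffers first and over the vectors inside each buffer second. Below is a minimal standalone sketch of the same per-buffer unit-length normalization; the class and method names are illustrative, not part of the PR:

```java
import java.nio.DoubleBuffer;

// Sketch only: scale every layerSize-length vector in every buffer to unit
// length, mirroring the patched normalize() above.
class NormalizeSketch {
    static void normalize(DoubleBuffer[] vectors, int layerSize) {
        for (DoubleBuffer buffer : vectors) {
            int vectorCount = buffer.limit() / layerSize; // whole vectors in this buffer
            for (int v = 0; v < vectorCount; v++) {
                double len = 0;
                for (int k = v * layerSize; k < (v + 1) * layerSize; k++)
                    len += buffer.get(k) * buffer.get(k);
                len = Math.sqrt(len);
                for (int k = v * layerSize; k < (v + 1) * layerSize; k++)
                    buffer.put(k, buffer.get(k) / len);
            }
        }
    }
}
```
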
51 changes: 30 additions & 21 deletions src/main/java/com/medallia/word2vec/SearcherImpl.java
@@ -1,6 +1,7 @@
package com.medallia.word2vec;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
@@ -14,19 +15,27 @@
/** Implementation of {@link Searcher} */
class SearcherImpl implements Searcher {
private final NormalizedWord2VecModel model;
private final ImmutableMap<String, Integer> word2vectorOffset;

SearcherImpl(final NormalizedWord2VecModel model) {
this.model = model;

final ImmutableMap.Builder<String, Integer> result = ImmutableMap.builder();
for (int i = 0; i < model.vocab.size(); i++) {
result.put(model.vocab.get(i), i * model.layerSize);
private final ImmutableMap<String, Long> word2vectorOffset;
private final int bufferSize;

SearcherImpl(final NormalizedWord2VecModel model) {
this.bufferSize = model.layerSize * model.vectorsPerBuffer;
long maxIndex = ((long) model.vocab.size() - 1) * model.layerSize;
// We use the vocab index divided by the buffer size as an array index, so it must fit into an int.
Preconditions.checkArgument(
maxIndex / bufferSize < Integer.MAX_VALUE,
"vocabulary and / or vector size is too large to calculate indexes for"
);
this.model = model;

final ImmutableMap.Builder<String, Long> result = ImmutableMap.builder();
for (int i = 0; i < model.vocab.size(); i++) {
result.put(model.vocab.get(i), ((long) i) * model.layerSize);
}

word2vectorOffset = result.build();
}

word2vectorOffset = result.build();
}

SearcherImpl(final Word2VecModel model) {
this(NormalizedWord2VecModel.fromWord2VecModel(model));
}
@@ -80,17 +89,17 @@ private double[] getVector(String word) throws UnknownWordException {
return result;
}

private double[] getVectorOrNull(final String word) {
final Integer index = word2vectorOffset.get(word);
if(index == null)
return null;
private double[] getVectorOrNull(final String word) {
final Long index = word2vectorOffset.get(word);
if(index == null)
return null;

final DoubleBuffer vectors = model.vectors.duplicate();
double[] result = new double[model.layerSize];
vectors.position(index);
vectors.get(result);
return result;
}
final DoubleBuffer vectors = model.vectors[(int) (index / bufferSize)].duplicate();
double[] result = new double[model.layerSize];
vectors.position((int) (index % bufferSize));
vectors.get(result);
return result;
}

/** @return Vector difference from v1 to v2 */
private double[] getDifference(double[] v1, double[] v2) {
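
The invariant behind getVectorOrNull: a word's flat offset, stored as a long count of doubles, splits into a buffer index (offset / bufferSize) and a position inside that buffer (offset % bufferSize), and the Preconditions check in the constructor guarantees the quotient fits in an int. A worked example with illustrative numbers, not taken from the PR:

```java
// Arithmetic check of the offset scheme; all values are hypothetical.
public class OffsetMathDemo {
    public static void main(String[] args) {
        int layerSize = 200;                            // doubles per vector
        int vectorsPerBuffer = 500;                     // vectors per DoubleBuffer
        int bufferSize = layerSize * vectorsPerBuffer;  // 100_000 doubles

        long wordIndex = 1234;
        long offset = wordIndex * layerSize;            // 246_800

        int bufferIndex = (int) (offset / bufferSize);  // 2
        int position = (int) (offset % bufferSize);     // 46_800
        System.out.printf("buffer %d, position %d%n", bufferIndex, position);
    }
}
```
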
158 changes: 102 additions & 56 deletions src/main/java/com/medallia/word2vec/Word2VecModel.java
@@ -39,17 +39,26 @@
public class Word2VecModel {
final List<String> vocab;
final int layerSize;
final DoubleBuffer vectors;
/** The max number of vectors stored in each DoubleBuffer. */
final int vectorsPerBuffer;
final DoubleBuffer[] vectors;
private final static long ONE_GB = 1024 * 1024 * 1024;
/**
* The maximum size we will build a double buffer, in doubles. Since we use
* memory-mapped byte buffers, and these have their size specified with an
* int, the most doubles we can store is Integer.MAX_VALUE / 8.
*/
private final static int MAX_DOUBLE_BUFFER = Integer.MAX_VALUE / 8;

Word2VecModel(Iterable<String> vocab, int layerSize, DoubleBuffer vectors) {
Word2VecModel(Iterable<String> vocab, int layerSize, DoubleBuffer[] vectors) {
this.vocab = ImmutableList.copyOf(vocab);
this.layerSize = layerSize;
this.vectors = vectors;
this.vectorsPerBuffer = vectors[0].limit() / layerSize;
}

Word2VecModel(Iterable<String> vocab, int layerSize, double[] vectors) {
this(vocab, layerSize, DoubleBuffer.wrap(vectors));
this(vocab, layerSize, new DoubleBuffer[] { DoubleBuffer.wrap(vectors) });
}

/** @return Vocabulary */
@@ -65,12 +74,21 @@ public Searcher forSearch() {
/** @return Serializable thrift representation */
public Word2VecModelThrift toThrift() {
double[] vectorsArray;
if(vectors.hasArray()) {
vectorsArray = vectors.array();
if(vectors.length == 1 && vectors[0].hasArray()) {
vectorsArray = vectors[0].array();
} else {
vectorsArray = new double[vectors.limit()];
vectors.position(0);
vectors.get(vectorsArray);
int totalSize = 0;
for (DoubleBuffer buffer : vectors) {
totalSize += buffer.limit();
}
vectorsArray = new double[totalSize];
int copiedCount = 0;
for (DoubleBuffer buffer : vectors) {
int size = buffer.limit();
buffer.position(0);
buffer.get(vectorsArray, copiedCount, size);
copiedCount += size;
}
}

return new Word2VecModelThrift()
@@ -119,6 +137,16 @@ public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder)
*/
public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer)
throws IOException {
return fromBinFile(file, byteOrder, timer, MAX_DOUBLE_BUFFER);
}

/**
* Testable version, with injected max double buffer size.
* @return {@link Word2VecModel} created from the binary representation output
* by the open source C version of word2vec using the given byte order.
*/
public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer, int maxDoubleBufferSize)
throws IOException {

try (
final FileInputStream fis = new FileInputStream(file);
@@ -156,58 +184,72 @@ public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, Profilin
vocabSize,
layerSize));

List<String> vocabs = new ArrayList<String>(vocabSize);
DoubleBuffer vectors = ByteBuffer.allocateDirect(vocabSize * layerSize * 8).asDoubleBuffer();
// Build up enough DoubleBuffers to store all of the vectors we'll be loading.
int vectorsPerBuffer = maxDoubleBufferSize / layerSize;
int numBuffers = vocabSize / vectorsPerBuffer + (vocabSize % vectorsPerBuffer != 0 ? 1 : 0);
DoubleBuffer[] vectors = new DoubleBuffer[numBuffers];
int remainingVectors = vocabSize;
for (int i = 0; remainingVectors > vectorsPerBuffer; i++, remainingVectors -= vectorsPerBuffer) {
vectors[i] = ByteBuffer.allocateDirect(vectorsPerBuffer * layerSize * 8).asDoubleBuffer();
}
if (remainingVectors > 0) {
vectors[numBuffers - 1] = ByteBuffer.allocateDirect(remainingVectors * layerSize * 8).asDoubleBuffer();
}

List<String> vocabs = new ArrayList<String>(vocabSize);
long lastLogMessage = System.currentTimeMillis();
final float[] floats = new float[layerSize];
for (int lineno = 0; lineno < vocabSize; lineno++) {
// read vocab
sb.setLength(0);
c = (char) buffer.get();
while (c != ' ') {
// ignore newlines in front of words (some binary files have newline,
// some don't)
if (c != '\n') {
sb.append(c);
}
int lineno = 0;
for (int buffno = 0; buffno < vectors.length; buffno++) {
DoubleBuffer vectorBuffer = vectors[buffno];
for (int vecno = 0; vecno < Math.min(vectorsPerBuffer, vectorBuffer.limit() / layerSize); vecno++, lineno++) {
// read vocab
sb.setLength(0);
c = (char) buffer.get();
}
vocabs.add(sb.toString());
while (c != ' ') {
// ignore newlines in front of words (some binary files have newline,
// some don't)
if (c != '\n') {
sb.append(c);
}
c = (char) buffer.get();
}
vocabs.add(sb.toString());

// read vector
final FloatBuffer floatBuffer = buffer.asFloatBuffer();
floatBuffer.get(floats);
for (int i = 0; i < floats.length; ++i) {
vectors.put(lineno * layerSize + i, floats[i]);
}
buffer.position(buffer.position() + 4 * layerSize);

// print log
final long now = System.currentTimeMillis();
if (now - lastLogMessage > 1000) {
final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0;
timer.appendToLog(
String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage));
lastLogMessage = now;
}
// read vector
final FloatBuffer floatBuffer = buffer.asFloatBuffer();
floatBuffer.get(floats);
for (int i = 0; i < floats.length; ++i) {
vectorBuffer.put(vecno * layerSize + i, floats[i]);
}
buffer.position(buffer.position() + 4 * layerSize);

// print log
final long now = System.currentTimeMillis();
if (now - lastLogMessage > 1000) {
final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0;
timer.appendToLog(
String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage));
lastLogMessage = now;
}

// remap file
if (buffer.position() > ONE_GB) {
final int newPosition = (int) (buffer.position() - ONE_GB);
final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE);
timer.endAndStart(
"Reading gigabyte #%d. Start: %d, size: %d",
bufferCount,
ONE_GB * bufferCount,
size);
buffer = channel.map(
FileChannel.MapMode.READ_ONLY,
ONE_GB * bufferCount,
size);
buffer.order(byteOrder);
buffer.position(newPosition);
bufferCount += 1;
// remap file
if (buffer.position() > ONE_GB) {
final int newPosition = (int) (buffer.position() - ONE_GB);
final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE);
timer.endAndStart(
"Reading gigabyte #%d. Start: %d, size: %d",
bufferCount,
ONE_GB * bufferCount,
size);
buffer = channel.map(
FileChannel.MapMode.READ_ONLY,
ONE_GB * bufferCount,
size);
buffer.order(byteOrder);
buffer.position(newPosition);
bufferCount += 1;
}
}
}
timer.end();
@@ -227,11 +269,15 @@ public void toBinFile(final OutputStream out) throws IOException {
final double[] vector = new double[layerSize];
final ByteBuffer buffer = ByteBuffer.allocate(4 * layerSize);
buffer.order(ByteOrder.LITTLE_ENDIAN); // The C version uses this byte order.

for(int i = 0; i < vocab.size(); ++i) {
out.write(String.format("%s ", vocab.get(i)).getBytes(cs));

vectors.position(i * layerSize);
vectors.get(vector);
DoubleBuffer vectorBuffer = vectors[i / vectorsPerBuffer];
vectorBuffer.position((i % vectorsPerBuffer) * layerSize);
vectorBuffer.get(vector);
buffer.clear();
for(int j = 0; j < layerSize; ++j)
buffer.putFloat((float)vector[j]);
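
The allocation arithmetic in fromBinFile deserves a note: a direct ByteBuffer's capacity is an int counted in bytes, and a double takes 8 bytes, so one buffer holds at most Integer.MAX_VALUE / 8 doubles; the loader rounds that down to a whole number of vectors and allocates ceil(vocabSize / vectorsPerBuffer) buffers. A sketch of that math using the same numbers as the test below (the class name is illustrative):

```java
// Buffer-split arithmetic, mirroring fromBinFile; values match the
// tokensModel test case rather than a real deployment.
public class BufferSplitDemo {
    public static void main(String[] args) {
        int layerSize = 200;
        int vocabSize = 1186;
        int maxDoubleBufferSize = 500 * 200;   // injected cap, as in the test

        int vectorsPerBuffer = maxDoubleBufferSize / layerSize;     // 500
        int numBuffers = vocabSize / vectorsPerBuffer
                + (vocabSize % vectorsPerBuffer != 0 ? 1 : 0);      // 3
        int lastBufferVectors = vocabSize % vectorsPerBuffer;       // 186

        System.out.printf("%d buffers, last holds %d vectors%n",
                numBuffers, lastBufferVectors == 0 ? vectorsPerBuffer : lastBufferVectors);
    }
}
```
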
23 changes: 23 additions & 0 deletions src/test/java/com/medallia/word2vec/Word2VecBinTest.java
@@ -6,6 +6,7 @@
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
@@ -16,6 +17,7 @@

import com.medallia.word2vec.Searcher.UnknownWordException;
import com.medallia.word2vec.util.Common;
import com.medallia.word2vec.util.ProfilingTimer;

/**
* Tests converting the binary models into
@@ -68,6 +70,27 @@ public void testRoundTrip() throws IOException, UnknownWordException {
assertEquals(model, modelCopy);
}

/**
* Tests that a Word2VecModel can load properly if the file doesn't fit into one buffer.
*/
@Test
public void testLargeFile() throws IOException, UnknownWordException {
File binFile = Common.getResourceAsFile(
this.getClass(),
"/com/medallia/word2vec/tokensModel.bin");
// The tokens model has 1186 words of 200 doubles each. Make it store 500 vectors per buffer.
Word2VecModel binModel =
Word2VecModel.fromBinFile(binFile, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE, 500 * 200);
Assert.assertEquals("binary file should have been split into 3", 3, binModel.vectors.length);

File txtFile = Common.getResourceAsFile(
this.getClass(),
"/com/medallia/word2vec/tokensModel.txt");
Word2VecModel txtModel = Word2VecModel.fromTextFile(txtFile);

assertEquals(binModel, txtModel);
}

@After
public void cleanupTempFile() throws IOException {
if(tempFile != null)
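
For completeness, loading stays unchanged at the call site; the buffer split is internal. A hypothetical end-to-end usage under the default cap (the file path is illustrative):

```java
import java.io.File;
import java.io.IOException;
import java.nio.ByteOrder;

import com.medallia.word2vec.Searcher;
import com.medallia.word2vec.Word2VecModel;

public class LoadDemo {
    public static void main(String[] args) throws IOException {
        File binFile = new File("vectors.bin");  // hypothetical model file
        Word2VecModel model = Word2VecModel.fromBinFile(binFile, ByteOrder.LITTLE_ENDIAN);
        Searcher searcher = model.forSearch();   // normalized copy spans all buffers
    }
}
```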