diff --git a/src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java b/src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java index a869d76..c2780ce 100644 --- a/src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java +++ b/src/main/java/com/medallia/word2vec/NormalizedWord2VecModel.java @@ -11,7 +11,7 @@ * Represents a word2vec model where all the vectors are normalized to unit length. */ public class NormalizedWord2VecModel extends Word2VecModel { - private NormalizedWord2VecModel(Iterable vocab, int layerSize, final DoubleBuffer vectors) { + private NormalizedWord2VecModel(Iterable vocab, int layerSize, final DoubleBuffer[] vectors) { super(vocab, layerSize, vectors); normalize(); } @@ -22,7 +22,11 @@ private NormalizedWord2VecModel(Iterable vocab, int layerSize, double[] } public static NormalizedWord2VecModel fromWord2VecModel(Word2VecModel model) { - return new NormalizedWord2VecModel(model.vocab, model.layerSize, model.vectors.duplicate()); + DoubleBuffer[] newVectors = new DoubleBuffer[model.vectors.length]; + for (int i = 0; i < newVectors.length; i++) { + newVectors[i] = model.vectors[i].duplicate(); + } + return new NormalizedWord2VecModel(model.vocab, model.layerSize, newVectors); } /** @return {@link NormalizedWord2VecModel} created from a thrift representation */ @@ -36,14 +40,17 @@ public static NormalizedWord2VecModel fromBinFile(final File file) throws IOExce /** Normalizes the vectors in this model */ private void normalize() { - for(int i = 0; i < vocab.size(); ++i) { - double len = 0; - for(int j = i * layerSize; j < (i + 1) * layerSize; ++j) - len += vectors.get(j) * vectors.get(j); - len = Math.sqrt(len); - - for(int j = i * layerSize; j < (i + 1) * layerSize; ++j) - vectors.put(j, vectors.get(j) / len); + for(int i = 0; i < vectors.length; ++i) { + DoubleBuffer buffer = vectors[i]; + for(int j = 0; j < Math.min(vectorsPerBuffer, buffer.limit() / layerSize); ++j) { + double len = 0; + for(int k = j * layerSize; k < 
(j + 1) * layerSize; ++k) + len += buffer.get(k) * buffer.get(k); + len = Math.sqrt(len); + + for(int k = j * layerSize; k < (j + 1) * layerSize; ++k) + buffer.put(k, buffer.get(k) / len); + } } } } diff --git a/src/main/java/com/medallia/word2vec/SearcherImpl.java b/src/main/java/com/medallia/word2vec/SearcherImpl.java index 56d9238..03e779e 100644 --- a/src/main/java/com/medallia/word2vec/SearcherImpl.java +++ b/src/main/java/com/medallia/word2vec/SearcherImpl.java @@ -1,6 +1,7 @@ package com.medallia.word2vec; import com.google.common.base.Function; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -14,19 +15,27 @@ /** Implementation of {@link Searcher} */ class SearcherImpl implements Searcher { private final NormalizedWord2VecModel model; - private final ImmutableMap word2vectorOffset; - - SearcherImpl(final NormalizedWord2VecModel model) { - this.model = model; - - final ImmutableMap.Builder result = ImmutableMap.builder(); - for (int i = 0; i < model.vocab.size(); i++) { - result.put(model.vocab.get(i), i * model.layerSize); + private final ImmutableMap word2vectorOffset; + private final int bufferSize; + + SearcherImpl(final NormalizedWord2VecModel model) { + this.bufferSize = model.layerSize * model.vectorsPerBuffer; + long maxIndex = ((long) model.vocab.size() - 1) * model.layerSize; + // We use the vocab index divided by the buffer size as an array index, so it must fit into an int. 
+ Preconditions.checkArgument( + maxIndex / bufferSize < Integer.MAX_VALUE, + "vocabulary and / or vector size is too large to calculate indexes for" + ); + this.model = model; + + final ImmutableMap.Builder result = ImmutableMap.builder(); + for (int i = 0; i < model.vocab.size(); i++) { + result.put(model.vocab.get(i), ((long) i) * model.layerSize); + } + + word2vectorOffset = result.build(); } - word2vectorOffset = result.build(); - } - SearcherImpl(final Word2VecModel model) { this(NormalizedWord2VecModel.fromWord2VecModel(model)); } @@ -80,17 +89,17 @@ private double[] getVector(String word) throws UnknownWordException { return result; } - private double[] getVectorOrNull(final String word) { - final Integer index = word2vectorOffset.get(word); - if(index == null) - return null; + private double[] getVectorOrNull(final String word) { + final Long index = word2vectorOffset.get(word); + if(index == null) + return null; - final DoubleBuffer vectors = model.vectors.duplicate(); - double[] result = new double[model.layerSize]; - vectors.position(index); - vectors.get(result); - return result; - } + final DoubleBuffer vectors = model.vectors[(int) (index / bufferSize)].duplicate(); + double[] result = new double[model.layerSize]; + vectors.position((int) (index % bufferSize)); + vectors.get(result); + return result; + } /** @return Vector difference from v1 to v2 */ private double[] getDifference(double[] v1, double[] v2) { diff --git a/src/main/java/com/medallia/word2vec/Word2VecModel.java b/src/main/java/com/medallia/word2vec/Word2VecModel.java index 5fa6b25..6299241 100644 --- a/src/main/java/com/medallia/word2vec/Word2VecModel.java +++ b/src/main/java/com/medallia/word2vec/Word2VecModel.java @@ -39,17 +39,26 @@ public class Word2VecModel { final List vocab; final int layerSize; - final DoubleBuffer vectors; + /** The max number of vectors stored in each DoubleBuffer. 
*/ + final int vectorsPerBuffer; + final DoubleBuffer[] vectors; private final static long ONE_GB = 1024 * 1024 * 1024; + /** + * The maximum size we will build a double buffer, in doubles. Since we use + * memory-mapped byte buffers, and these have their size specified with an + * int, the most doubles we can store is Integer.MAX_VALUE / 8. + */ + private final static int MAX_DOUBLE_BUFFER = Integer.MAX_VALUE / 8; - Word2VecModel(Iterable vocab, int layerSize, DoubleBuffer vectors) { + Word2VecModel(Iterable vocab, int layerSize, DoubleBuffer[] vectors) { this.vocab = ImmutableList.copyOf(vocab); this.layerSize = layerSize; this.vectors = vectors; + this.vectorsPerBuffer = vectors[0].limit() / layerSize; } Word2VecModel(Iterable vocab, int layerSize, double[] vectors) { - this(vocab, layerSize, DoubleBuffer.wrap(vectors)); + this(vocab, layerSize, new DoubleBuffer[] { DoubleBuffer.wrap(vectors) }); } /** @return Vocabulary */ @@ -65,12 +74,21 @@ public Searcher forSearch() { /** @return Serializable thrift representation */ public Word2VecModelThrift toThrift() { double[] vectorsArray; - if(vectors.hasArray()) { - vectorsArray = vectors.array(); + if(vectors.length == 1 && vectors[0].hasArray()) { + vectorsArray = vectors[0].array(); } else { - vectorsArray = new double[vectors.limit()]; - vectors.position(0); - vectors.get(vectorsArray); + int totalSize = 0; + for (DoubleBuffer buffer : vectors) { + totalSize += buffer.limit(); + } + vectorsArray = new double[totalSize]; + int copiedCount = 0; + for (DoubleBuffer buffer : vectors) { + int size = buffer.limit(); + buffer.position(0); + buffer.get(vectorsArray, copiedCount, size); + copiedCount += size; + } } return new Word2VecModelThrift() @@ -119,6 +137,16 @@ public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder) */ public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer) throws IOException { + return fromBinFile(file, byteOrder, timer, MAX_DOUBLE_BUFFER); + 
} + + /** + * Testable version, with injected max double buffer size. + * @return {@link Word2VecModel} created from the binary representation output + * by the open source C version of word2vec using the given byte order. + */ + public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, ProfilingTimer timer, int maxDoubleBufferSize) + throws IOException { try ( final FileInputStream fis = new FileInputStream(file); @@ -156,58 +184,72 @@ public static Word2VecModel fromBinFile(File file, ByteOrder byteOrder, Profilin vocabSize, layerSize)); - List vocabs = new ArrayList(vocabSize); - DoubleBuffer vectors = ByteBuffer.allocateDirect(vocabSize * layerSize * 8).asDoubleBuffer(); + // Build up enough DoubleBuffers to store all of the vectors we'll be loading. + int vectorsPerBuffer = maxDoubleBufferSize / layerSize; + int numBuffers = vocabSize / vectorsPerBuffer + (vocabSize % vectorsPerBuffer != 0 ? 1 : 0); + DoubleBuffer[] vectors = new DoubleBuffer[numBuffers]; + int remainingVectors = vocabSize; + for (int i = 0; remainingVectors > vectorsPerBuffer; i++, remainingVectors -= vectorsPerBuffer) { + vectors[i] = ByteBuffer.allocateDirect(vectorsPerBuffer * layerSize * 8).asDoubleBuffer(); + } + if (remainingVectors > 0) { + vectors[numBuffers - 1] = ByteBuffer.allocateDirect(remainingVectors * layerSize * 8).asDoubleBuffer(); + } + List vocabs = new ArrayList(vocabSize); long lastLogMessage = System.currentTimeMillis(); final float[] floats = new float[layerSize]; - for (int lineno = 0; lineno < vocabSize; lineno++) { - // read vocab - sb.setLength(0); - c = (char) buffer.get(); - while (c != ' ') { - // ignore newlines in front of words (some binary files have newline, - // some don't) - if (c != '\n') { - sb.append(c); - } + int lineno = 0; + for (int buffno = 0; buffno < vectors.length; buffno++) { + DoubleBuffer vectorBuffer = vectors[buffno]; + for (int vecno = 0; vecno < Math.min(vectorsPerBuffer, vectorBuffer.limit()/ layerSize); vecno++, lineno++) 
{ + // read vocab + sb.setLength(0); c = (char) buffer.get(); - } - vocabs.add(sb.toString()); + while (c != ' ') { + // ignore newlines in front of words (some binary files have newline, + // some don't) + if (c != '\n') { + sb.append(c); + } + c = (char) buffer.get(); + } + vocabs.add(sb.toString()); - // read vector - final FloatBuffer floatBuffer = buffer.asFloatBuffer(); - floatBuffer.get(floats); - for (int i = 0; i < floats.length; ++i) { - vectors.put(lineno * layerSize + i, floats[i]); - } - buffer.position(buffer.position() + 4 * layerSize); - - // print log - final long now = System.currentTimeMillis(); - if (now - lastLogMessage > 1000) { - final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0; - timer.appendToLog( - String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage)); - lastLogMessage = now; - } + // read vector + final FloatBuffer floatBuffer = buffer.asFloatBuffer(); + floatBuffer.get(floats); + for (int i = 0; i < floats.length; ++i) { + vectorBuffer.put(vecno * layerSize + i, floats[i]); + } + buffer.position(buffer.position() + 4 * layerSize); + + // print log + final long now = System.currentTimeMillis(); + if (now - lastLogMessage > 1000) { + final double percentage = ((double) (lineno + 1) / (double) vocabSize) * 100.0; + timer.appendToLog( + String.format("Loaded %d/%d vectors (%f%%)", lineno + 1, vocabSize, percentage)); + lastLogMessage = now; + } - // remap file - if (buffer.position() > ONE_GB) { - final int newPosition = (int) (buffer.position() - ONE_GB); - final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE); - timer.endAndStart( - "Reading gigabyte #%d. 
Start: %d, size: %d", - bufferCount, - ONE_GB * bufferCount, - size); - buffer = channel.map( - FileChannel.MapMode.READ_ONLY, - ONE_GB * bufferCount, - size); - buffer.order(byteOrder); - buffer.position(newPosition); - bufferCount += 1; + // remap file + if (buffer.position() > ONE_GB) { + final int newPosition = (int) (buffer.position() - ONE_GB); + final long size = Math.min(channel.size() - ONE_GB * bufferCount, Integer.MAX_VALUE); + timer.endAndStart( + "Reading gigabyte #%d. Start: %d, size: %d", + bufferCount, + ONE_GB * bufferCount, + size); + buffer = channel.map( + FileChannel.MapMode.READ_ONLY, + ONE_GB * bufferCount, + size); + buffer.order(byteOrder); + buffer.position(newPosition); + bufferCount += 1; + } } } timer.end(); @@ -227,11 +269,15 @@ public void toBinFile(final OutputStream out) throws IOException { final double[] vector = new double[layerSize]; final ByteBuffer buffer = ByteBuffer.allocate(4 * layerSize); buffer.order(ByteOrder.LITTLE_ENDIAN); // The C version uses this byte order. 
+ + int vectorsPerBuffer = this.vectorsPerBuffer; + for(int i = 0; i < vocab.size(); ++i) { out.write(String.format("%s ", vocab.get(i)).getBytes(cs)); - vectors.position(i * layerSize); - vectors.get(vector); + DoubleBuffer vectorBuffer = vectors[i / vectorsPerBuffer]; + vectorBuffer.position((i % vectorsPerBuffer) * layerSize); + vectorBuffer.get(vector); buffer.clear(); for(int j = 0; j < layerSize; ++j) buffer.putFloat((float)vector[j]); diff --git a/src/test/java/com/medallia/word2vec/Word2VecBinTest.java b/src/test/java/com/medallia/word2vec/Word2VecBinTest.java index 9964fa6..44a9e27 100644 --- a/src/test/java/com/medallia/word2vec/Word2VecBinTest.java +++ b/src/test/java/com/medallia/word2vec/Word2VecBinTest.java @@ -6,6 +6,7 @@ import java.io.File; import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; import java.util.List; @@ -16,6 +17,7 @@ import com.medallia.word2vec.Searcher.UnknownWordException; import com.medallia.word2vec.util.Common; +import com.medallia.word2vec.util.ProfilingTimer; /** * Tests converting the binary models into @@ -68,6 +70,27 @@ public void testRoundTrip() throws IOException, UnknownWordException { assertEquals(model, modelCopy); } + /** + * Tests that a Word2VecModel can load properly if the file doesn't fit into one buffer. + */ + @Test + public void testLargeFile() throws IOException, UnknownWordException { + File binFile = Common.getResourceAsFile( + this.getClass(), + "/com/medallia/word2vec/tokensModel.bin"); + // The tokens model has 1186 words of 200 doubles each. Make it store 500 vectors per buffer. 
+ Word2VecModel binModel = + Word2VecModel.fromBinFile(binFile, ByteOrder.LITTLE_ENDIAN, ProfilingTimer.NONE, 500 * 200); + Assert.assertEquals("binary file should have been split in 3", 3, binModel.vectors.length); + + File txtFile = Common.getResourceAsFile( + this.getClass(), + "/com/medallia/word2vec/tokensModel.txt"); + Word2VecModel txtModel = Word2VecModel.fromTextFile(txtFile); + + assertEquals(binModel, txtModel); + } + @After public void cleanupTempFile() throws IOException { if(tempFile != null)