Skip to content

Commit ff318c0

Browse files
committed
[SPARK-21050][ML] Word2vec persistence overflow bug fix
## What changes were proposed in this pull request? The method calculateNumberOfPartitions() uses Int, not Long (unlike the MLlib version), so it is very easily to have an overflow in calculating the number of partitions for ML persistence. This modifies the calculations to use Long. ## How was this patch tested? New unit test. I verified that the test fails before this patch. Author: Joseph K. Bradley <joseph@databricks.com> Closes #18265 from jkbradley/word2vec-save-fix.
1 parent b1436c7 commit ff318c0

2 files changed

Lines changed: 38 additions & 10 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
1919

2020
import org.apache.hadoop.fs.Path
2121

22+
import org.apache.spark.SparkContext
2223
import org.apache.spark.annotation.Since
2324
import org.apache.spark.ml.{Estimator, Model}
2425
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
339340
val wordVectors = instance.wordVectors.getVectors
340341
val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
341342
val dataPath = new Path(path, "data").toString
343+
val bufferSizeInBytes = Utils.byteStringAsBytes(
344+
sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
345+
val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
346+
bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
342347
sparkSession.createDataFrame(dataSeq)
343-
.repartition(calculateNumberOfPartitions)
348+
.repartition(numPartitions)
344349
.write
345350
.parquet(dataPath)
346351
}
352+
}
347353

348-
def calculateNumberOfPartitions(): Int = {
349-
val floatSize = 4
354+
private[feature]
355+
object Word2VecModelWriter {
356+
/**
357+
* Calculate the number of partitions to use in saving the model.
358+
* [SPARK-11994] - We want to partition the model in partitions smaller than
359+
* spark.kryoserializer.buffer.max
360+
* @param bufferSizeInBytes Set to spark.kryoserializer.buffer.max
361+
* @param numWords Vocab size
362+
* @param vectorSize Vector length for each word
363+
*/
364+
def calculateNumberOfPartitions(
365+
bufferSizeInBytes: Long,
366+
numWords: Int,
367+
vectorSize: Int): Int = {
368+
val floatSize = 4L // Use Long to help avoid overflow
350369
val averageWordSize = 15
351-
// [SPARK-11994] - We want to partition the model in partitions smaller than
352-
// spark.kryoserializer.buffer.max
353-
val bufferSizeInBytes = Utils.byteStringAsBytes(
354-
sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
355370
// Calculate the approximate size of the model.
356371
// Assuming an average word size of 15 bytes, the formula is:
357372
// (floatSize * vectorSize + 15) * numWords
358-
val numWords = instance.wordVectors.wordIndex.size
359-
val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
360-
((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
373+
val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
374+
val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
375+
require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " +
376+
s"partitions to save this model, which is too large. Try increasing " +
377+
s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
378+
numPartitions.toInt
361379
}
362380
}
363381

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._
2525
import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
2626
import org.apache.spark.mllib.util.MLlibTestSparkContext
2727
import org.apache.spark.sql.Row
28+
import org.apache.spark.util.Utils
2829

2930
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3031

@@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
188189
assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
189190
}
190191

192+
test("Word2Vec read/write numPartitions calculation") {
193+
val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
194+
Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5)
195+
assert(smallModelNumPartitions === 1)
196+
val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
197+
Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000)
198+
assert(largeModelNumPartitions > 1)
199+
}
200+
191201
test("Word2Vec read/write") {
192202
val t = new Word2Vec()
193203
.setInputCol("myInputCol")

0 commit comments

Comments
 (0)