Skip to content

Commit 181b075

Browse files
author
Bartlomiej Alberski
committed
SPARK-11553 Primitive Row accessors throw NPE for null
1 parent 2d2411f commit 181b075

3 files changed

Lines changed: 65 additions & 23 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,39 +191,39 @@ trait Row extends Serializable {
191191
* @throws ClassCastException when data type does not match.
192192
* @throws NullPointerException when value is null.
193193
*/
194-
def getBoolean(i: Int): Boolean = getAs[Boolean](i)
194+
def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i)
195195

196196
/**
197197
* Returns the value at position i as a primitive byte.
198198
*
199199
* @throws ClassCastException when data type does not match.
200200
* @throws NullPointerException when value is null.
201201
*/
202-
def getByte(i: Int): Byte = getAs[Byte](i)
202+
def getByte(i: Int): Byte = getAnyValAs[Byte](i)
203203

204204
/**
205205
* Returns the value at position i as a primitive short.
206206
*
207207
* @throws ClassCastException when data type does not match.
208208
* @throws NullPointerException when value is null.
209209
*/
210-
def getShort(i: Int): Short = getAs[Short](i)
210+
def getShort(i: Int): Short = getAnyValAs[Short](i)
211211

212212
/**
213213
* Returns the value at position i as a primitive int.
214214
*
215215
* @throws ClassCastException when data type does not match.
216216
* @throws NullPointerException when value is null.
217217
*/
218-
def getInt(i: Int): Int = getAs[Int](i)
218+
def getInt(i: Int): Int = getAnyValAs[Int](i)
219219

220220
/**
221221
* Returns the value at position i as a primitive long.
222222
*
223223
* @throws ClassCastException when data type does not match.
224224
* @throws NullPointerException when value is null.
225225
*/
226-
def getLong(i: Int): Long = getAs[Long](i)
226+
def getLong(i: Int): Long = getAnyValAs[Long](i)
227227

228228
/**
229229
* Returns the value at position i as a primitive float.
@@ -232,21 +232,20 @@ trait Row extends Serializable {
232232
* @throws ClassCastException when data type does not match.
233233
* @throws NullPointerException when value is null.
234234
*/
235-
def getFloat(i: Int): Float = getAs[Float](i)
235+
def getFloat(i: Int): Float = getAnyValAs[Float](i)
236236

237237
/**
238238
* Returns the value at position i as a primitive double.
239239
*
240240
* @throws ClassCastException when data type does not match.
241241
* @throws NullPointerException when value is null.
242242
*/
243-
def getDouble(i: Int): Double = getAs[Double](i)
243+
def getDouble(i: Int): Double = getAnyValAs[Double](i)
244244

245245
/**
246246
* Returns the value at position i as a String object.
247247
*
248248
* @throws ClassCastException when data type does not match.
249-
* @throws NullPointerException when value is null.
250249
*/
251250
def getString(i: Int): String = getAs[String](i)
252251

@@ -310,13 +309,17 @@ trait Row extends Serializable {
310309

311310
/**
312311
* Returns the value at position i.
312+
* For primitive types, if the value is null it returns the 'zero value' specific to that primitive
313+
* type, i.e. 0 for Int - use isNullAt to ensure that the value is not null
313314
*
314315
* @throws ClassCastException when data type does not match.
315316
*/
316317
def getAs[T](i: Int): T = get(i).asInstanceOf[T]
317318

318319
/**
319320
* Returns the value of a given fieldName.
321+
* For primitive types, if the value is null it returns the 'zero value' specific to that primitive
322+
* type, i.e. 0 for Int - use isNullAt to ensure that the value is not null
320323
*
321324
* @throws UnsupportedOperationException when schema is not defined.
322325
* @throws IllegalArgumentException when fieldName does not exist.
@@ -336,6 +339,8 @@ trait Row extends Serializable {
336339

337340
/**
338341
* Returns a Map(name -> value) for the requested fieldNames
342+
* For primitive types if value is null it returns 'zero value' specific for primitive
343+
* ie. 0 for Int - use isNullAt to ensure that value is not null
339344
*
340345
* @throws UnsupportedOperationException when schema is not defined.
341346
* @throws IllegalArgumentException when fieldName does not exist.
@@ -450,4 +455,15 @@ trait Row extends Serializable {
450455
* start, end, and separator strings.
451456
*/
452457
def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
458+
459+
/**
460+
* Returns the value at position i as a primitive (AnyVal) type.
461+
*
462+
* @throws UnsupportedOperationException when schema is not defined.
463+
* @throws ClassCastException when data type does not match.
464+
* @throws NullPointerException when value is null.
465+
*/
466+
private def getAnyValAs[T <: AnyVal](i: Int): T =
467+
if (isNullAt(i)) throw new NullPointerException(s"Value at index $i is null")
468+
else getAs[T](i)
453469
}

sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers {
2929
StructField("col2", StringType) ::
3030
StructField("col3", IntegerType) :: Nil)
3131
val values = Array("value1", "value2", 1)
32+
val valuesWithoutCol3 = Array[Any](null, "value2", null)
3233

3334
val sampleRow: Row = new GenericRowWithSchema(values, schema)
35+
val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema)
3436
val noSchemaRow: Row = new GenericRow(values)
3537

3638
describe("Row (without schema)") {
@@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers {
6870
)
6971
sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
7072
}
73+
74+
it("getValuesMap() retrieves null value on non AnyVal Type") {
75+
val expected = Map(
76+
"col1" -> null,
77+
"col2" -> "value2"
78+
)
79+
sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected
80+
}
81+
82+
it("getAs() on type extending AnyVal throws an exception when accessing field that is null") {
83+
intercept[NullPointerException] {
84+
sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3"))
85+
}
86+
}
87+
88+
it("getAs() on type not extending AnyVal does not throw exception when value is null") {
89+
sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null
90+
}
7191
}
7292

7393
describe("row equals") {

sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,14 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
5858
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
5959
val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
6060
val actualOutput = hashJoinNode.collect().map { row =>
61-
// (id, name, id, nickname)
62-
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
61+
// (
62+
// id, name,
63+
// id, nickname
64+
// )
65+
(
66+
Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)),
67+
Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3))
68+
)
6369
}
6470
assert(actualOutput.toSet === expectedOutput.toSet)
6571
}
@@ -95,36 +101,36 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
95101
private def generateExpectedOutput(
96102
leftInput: Array[(Int, String)],
97103
rightInput: Array[(Int, String)],
98-
joinType: JoinType): Array[(Int, String, Int, String)] = {
104+
joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = {
99105
joinType match {
100106
case LeftOuter =>
101107
val rightInputMap = rightInput.toMap
102108
leftInput.map { case (k, v) =>
103-
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
104-
val rightValue = rightInputMap.getOrElse(k, null)
105-
(k, v, rightKey, rightValue)
109+
val rightKey = rightInputMap.get(k).map { _ => k }
110+
val rightValue = rightInputMap.get(k)
111+
(Some(k), Some(v), rightKey, rightValue)
106112
}
107113

108114
case RightOuter =>
109115
val leftInputMap = leftInput.toMap
110116
rightInput.map { case (k, v) =>
111-
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
112-
val leftValue = leftInputMap.getOrElse(k, null)
113-
(leftKey, leftValue, k, v)
117+
val leftKey = leftInputMap.get(k).map { _ => k }
118+
val leftValue = leftInputMap.get(k)
119+
(leftKey, leftValue, Some(k), Some(v))
114120
}
115121

116122
case FullOuter =>
117123
val leftInputMap = leftInput.toMap
118124
val rightInputMap = rightInput.toMap
119125
val leftOutput = leftInput.map { case (k, v) =>
120-
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
121-
val rightValue = rightInputMap.getOrElse(k, null)
122-
(k, v, rightKey, rightValue)
126+
val rightKey = rightInputMap.get(k).map { _ => k }
127+
val rightValue = rightInputMap.get(k)
128+
(Some(k), Some(v), rightKey, rightValue)
123129
}
124130
val rightOutput = rightInput.map { case (k, v) =>
125-
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
126-
val leftValue = leftInputMap.getOrElse(k, null)
127-
(leftKey, leftValue, k, v)
131+
val leftKey = leftInputMap.get(k).map { _ => k }
132+
val leftValue = leftInputMap.get(k)
133+
(leftKey, leftValue, Some(k), Some(v))
128134
}
129135
(leftOutput ++ rightOutput).distinct
130136

0 commit comments

Comments
 (0)