Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,39 +191,39 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getBoolean(i: Int): Boolean = getAs[Boolean](i)
def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i)

/**
* Returns the value at position i as a primitive byte.
*
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getByte(i: Int): Byte = getAs[Byte](i)
def getByte(i: Int): Byte = getAnyValAs[Byte](i)

/**
* Returns the value at position i as a primitive short.
*
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getShort(i: Int): Short = getAs[Short](i)
def getShort(i: Int): Short = getAnyValAs[Short](i)

/**
* Returns the value at position i as a primitive int.
*
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getInt(i: Int): Int = getAs[Int](i)
def getInt(i: Int): Int = getAnyValAs[Int](i)

/**
* Returns the value at position i as a primitive long.
*
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getLong(i: Int): Long = getAs[Long](i)
def getLong(i: Int): Long = getAnyValAs[Long](i)

/**
* Returns the value at position i as a primitive float.
Expand All @@ -232,21 +232,20 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getFloat(i: Int): Float = getAs[Float](i)
def getFloat(i: Int): Float = getAnyValAs[Float](i)

/**
* Returns the value at position i as a primitive double.
*
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getDouble(i: Int): Double = getAs[Double](i)
def getDouble(i: Int): Double = getAnyValAs[Double](i)

/**
* Returns the value at position i as a String object.
*
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
def getString(i: Int): String = getAs[String](i)

Expand Down Expand Up @@ -310,13 +309,17 @@ trait Row extends Serializable {

/**
* Returns the value at position i.
* For primitive types if value is null it returns 'zero value' specific for primitive
* ie. 0 for Int - use isNullAt to ensure that value is not null
*
* @throws ClassCastException when data type does not match.
*/
def getAs[T](i: Int): T = get(i).asInstanceOf[T]

/**
* Returns the value of a given fieldName.
* For primitive types if value is null it returns 'zero value' specific for primitive
* ie. 0 for Int - use isNullAt to ensure that value is not null
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist.
Expand All @@ -336,6 +339,8 @@ trait Row extends Serializable {

/**
* Returns a Map(name -> value) for the requested fieldNames
* For primitive types if value is null it returns 'zero value' specific for primitive
* ie. 0 for Int - use isNullAt to ensure that value is not null
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist.
Expand Down Expand Up @@ -450,4 +455,15 @@ trait Row extends Serializable {
* start, end, and separator strings.
*/
def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)

/**
* Returns the value of a given fieldName.
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
private def getAnyValAs[T <: AnyVal](i: Int): T =
if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null")
else getAs[T](i)
}
20 changes: 20 additions & 0 deletions sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers {
StructField("col2", StringType) ::
StructField("col3", IntegerType) :: Nil)
val values = Array("value1", "value2", 1)
val valuesWithoutCol3 = Array[Any](null, "value2", null)

val sampleRow: Row = new GenericRowWithSchema(values, schema)
val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema)
val noSchemaRow: Row = new GenericRow(values)

describe("Row (without schema)") {
Expand Down Expand Up @@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers {
)
sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
}

it("getValuesMap() retrieves null value on non AnyVal Type") {
val expected = Map(
"col1" -> null,
"col2" -> "value2"
)
sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected
}

it("getAs() on type extending AnyVal throws an exception when accessing field that is null") {
intercept[NullPointerException] {
sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3"))
}
}

it("getAs() on type extending AnyVal does not throw exception when value is null"){
sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null
}
}

describe("row equals") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,14 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
val actualOutput = hashJoinNode.collect().map { row =>
// (id, name, id, nickname)
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
// (
// id, name,
// id, nickname
// )
(
Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)),
Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3))
)
}
assert(actualOutput.toSet === expectedOutput.toSet)
}
Expand Down Expand Up @@ -95,36 +101,36 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
private def generateExpectedOutput(
leftInput: Array[(Int, String)],
rightInput: Array[(Int, String)],
joinType: JoinType): Array[(Int, String, Int, String)] = {
joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = {
joinType match {
case LeftOuter =>
val rightInputMap = rightInput.toMap
leftInput.map { case (k, v) =>
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
val rightValue = rightInputMap.getOrElse(k, null)
(k, v, rightKey, rightValue)
val rightKey = rightInputMap.get(k).map { _ => k }
val rightValue = rightInputMap.get(k)
(Some(k), Some(v), rightKey, rightValue)
}

case RightOuter =>
val leftInputMap = leftInput.toMap
rightInput.map { case (k, v) =>
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
val leftValue = leftInputMap.getOrElse(k, null)
(leftKey, leftValue, k, v)
val leftKey = leftInputMap.get(k).map { _ => k }
val leftValue = leftInputMap.get(k)
(leftKey, leftValue, Some(k), Some(v))
}

case FullOuter =>
val leftInputMap = leftInput.toMap
val rightInputMap = rightInput.toMap
val leftOutput = leftInput.map { case (k, v) =>
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
val rightValue = rightInputMap.getOrElse(k, null)
(k, v, rightKey, rightValue)
val rightKey = rightInputMap.get(k).map { _ => k }
val rightValue = rightInputMap.get(k)
(Some(k), Some(v), rightKey, rightValue)
}
val rightOutput = rightInput.map { case (k, v) =>
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
val leftValue = leftInputMap.getOrElse(k, null)
(leftKey, leftValue, k, v)
val leftKey = leftInputMap.get(k).map { _ => k }
val leftValue = leftInputMap.get(k)
(leftKey, leftValue, Some(k), Some(v))
}
(leftOutput ++ rightOutput).distinct

Expand Down