sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

@@ -333,31 +333,12 @@ object ScalaReflection extends ScalaReflection {
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t

-    val keyData =
-      Invoke(
-        MapObjects(
-          p => deserializerFor(keyType, Some(p), walkedTypePath),
-          Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
-            returnNullable = false),
-          schemaFor(keyType).dataType),
-        "array",
-        ObjectType(classOf[Array[Any]]), returnNullable = false)
-
-    val valueData =
-      Invoke(
-        MapObjects(
-          p => deserializerFor(valueType, Some(p), walkedTypePath),
-          Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
-            returnNullable = false),
-          schemaFor(valueType).dataType),
-        "array",
-        ObjectType(classOf[Array[Any]]), returnNullable = false)
-
-    StaticInvoke(
-      ArrayBasedMapData.getClass,
-      ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
-      "toScalaMap",
-      keyData :: valueData :: Nil)
+    CollectObjectsToMap(
+      p => deserializerFor(keyType, Some(p), walkedTypePath),
+      p => deserializerFor(valueType, Some(p), walkedTypePath),
+      getPath,
+      mirror.runtimeClass(t.typeSymbol.asClass)
+    )

case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
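For intuition: the removed code deserialized keys and values into two separate Object arrays and zipped them with ArrayBasedMapData.toScalaMap, which could only ever produce a scala.collection.immutable.Map. CollectObjectsToMap (added in objects.scala below) instead drives a builder obtained from the target collection's companion object, so any Map subtype with a companion can be requested. A hand-written Scala sketch of the loop the generated code performs — for intuition only, not the actual codegen; LinkedHashMap stands in for the runtime collClass:

    import scala.collection.mutable

    def collectToMap[K, V](keys: Array[Any], values: Array[Any])(
        keyFn: Any => K, valueFn: Any => V): mutable.LinkedHashMap[K, V] = {
      require(keys.length == values.length, "Inconsistent lengths of key-value arrays")
      val builder = mutable.LinkedHashMap.newBuilder[K, V]
      builder.sizeHint(keys.length)
      var i = 0
      while (i < keys.length) {
        // null values bypass the value deserializer, mirroring valueLoopIsNull
        val value = if (values(i) == null) null.asInstanceOf[V] else valueFn(values(i))
        builder += keyFn(keys(i)) -> value
        i += 1
      }
      builder.result()
    }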
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -652,6 +652,177 @@ case class MapObjects private(
}
}

object CollectObjectsToMap {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* Constructs an instance of the CollectObjectsToMap case class.
*
* @param keyFunction The function applied on the key collection elements.
* @param valueFunction The function applied on the value collection elements.
* @param inputData An expression that when evaluated returns a map object.
* @param collClass The type of the resulting collection.
*/
def apply(
keyFunction: Expression => Expression,
valueFunction: Expression => Expression,
inputData: Expression,
collClass: Class[_]): CollectObjectsToMap = {
val id = curId.getAndIncrement()
val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
val mapType = inputData.dataType.asInstanceOf[MapType]
val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
CollectObjectsToMap(
keyLoopValue, keyFunction(keyLoopVar),
valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
inputData, collClass)
}
}

Contributor: let's update the class doc to explicitly say that this expression is used to convert a catalyst map to an external map.

/**
* An equivalent to the [[MapObjects]] case class but returning an ObjectType containing
* a Scala collection constructed using the associated builder, obtained by calling `newBuilder`
* on the collection's companion object.
*
* @param keyLoopValue the name of the loop variable that is used when iterating over the key
* collection, and which is used as input for the `keyLambdaFunction`
* @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
* a lambda function to handle collection elements.
* @param valueLoopValue the name of the loop variable that is used when iterating over the value
* collection, and which is used as input for the `valueLambdaFunction`
* @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
* the value collection, and which is used as input for the
* `valueLambdaFunction`
* @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
* a lambda function to handle collection elements.
* @param inputData An expression that when evaluated returns a map object.
* @param collClass The type of the resulting collection.
*/
case class CollectObjectsToMap private(
keyLoopValue: String,
keyLambdaFunction: Expression,
valueLoopValue: String,
valueLoopIsNull: String,
valueLambdaFunction: Expression,
inputData: Expression,
collClass: Class[_]) extends Expression with NonSQLExpression {

override def nullable: Boolean = inputData.nullable

override def children: Seq[Expression] =
keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def dataType: DataType = ObjectType(collClass)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mapType = inputData.dataType.asInstanceOf[MapType]
val keyElementJavaType = ctx.javaType(mapType.keyType)
ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
val genKeyFunction = keyLambdaFunction.genCode(ctx)
val valueElementJavaType = ctx.javaType(mapType.valueType)
ctx.addMutableState("boolean", valueLoopIsNull, "")
ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
val genValueFunction = valueLambdaFunction.genCode(ctx)
val genInputData = inputData.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
val loopIndex = ctx.freshName("loopIndex")
val tupleLoopValue = ctx.freshName("tupleLoopValue")
val builderValue = ctx.freshName("builderValue")

val keyArray = ctx.freshName("keyArray")
val valueArray = ctx.freshName("valueArray")

// Data with a PythonUserDefinedType is actually stored with the data type of its
// sqlType, so we have to unwrap it before iterating over the map's contents.
def inputDataType(dataType: DataType) = dataType match {
case p: PythonUserDefinedType => p.sqlType
case _ => dataType
}

Contributor: the code in MapObjects is:

    val inputDataType = inputData.dataType match {
      case p: PythonUserDefinedType => p.sqlType
      case _ => inputData.dataType
    }

We should call this before we do val mapType = inputData.dataType.asInstanceOf[MapType].
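A minimal sketch of the reordering this comment suggests, under the assumption that the top-level cast should see the unwrapped type (not the merged code):

    // Hypothetical: resolve PythonUserDefinedType once, before the MapType cast,
    // instead of unwrapping per element type in a local helper. Key and value
    // element types may still need the same unwrapping separately.
    val mapType = inputData.dataType match {
      case p: PythonUserDefinedType => p.sqlType.asInstanceOf[MapType]
      case other => other.asInstanceOf[MapType]
    }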

def lengthAndLoopVar(elementType: DataType, genInputData: ExprCode, method: String,
array: String) =
s"${genInputData.value}.$method().numElements()" ->
ctx.getValue(s"${genInputData.value}.$method()", elementType, loopIndex)

Contributor: it's just a 2-line method, can we inline it?
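If the helper were inlined as suggested, the four strings might be computed directly; a sketch under that assumption (not the merged code):

    val getKeyLength = s"${genInputData.value}.keyArray().numElements()"
    val getKeyLoopVar =
      ctx.getValue(s"${genInputData.value}.keyArray()", inputDataType(mapType.keyType), loopIndex)
    val getValueLength = s"${genInputData.value}.valueArray().numElements()"
    val getValueLoopVar =
      ctx.getValue(s"${genInputData.value}.valueArray()", inputDataType(mapType.valueType), loopIndex)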

val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar)) = (
lengthAndLoopVar(inputDataType(mapType.keyType), genInputData, "keyArray", keyArray),
lengthAndLoopVar(inputDataType(mapType.valueType), genInputData, "valueArray", valueArray)
)

// Make a copy of the data if it's unsafe-backed
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
lambdaFunction.dataType match {
case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
case _ => genFunction.value
}
val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)

val valueLoopNullCheck =
s"$valueLoopIsNull = ${genInputData.value}.valueArray().isNullAt($loopIndex);"
Contributor: how about $valueArray.isNullAt($loopIndex)?
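Spelled out, the suggestion assumes the generated code first caches the value array in a local named by $valueArray, which the surrounding template would have to emit:

    // Hypothetical form against a cached value array (not the merged code):
    val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"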

val builderClass = classOf[Builder[_, _]].getName
val constructBuilder = s"""
$builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
$builderValue.sizeHint($dataLength);
"""

val tupleClass = classOf[(_, _)].getName
val appendToBuilder = s"""
$tupleClass $tupleLoopValue;

if (${genValueFunction.isNull}) {
$tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
} else {
$tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
}

$builderValue.$$plus$$eq($tupleLoopValue);
"""
val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
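Aside on $$plus$$eq: the generated source is Java, so Scala's += on the builder has to be called by its JVM-mangled name $plus$eq (the doubled $$ escapes $ inside the interpolator). In Scala terms the generated call is simply:

    val builder = scala.collection.immutable.Map.newBuilder[Int, String]
    builder += (1 -> "a")        // compiles to the JVM method $plus$eq
    println(builder.result())    // Map(1 -> a)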

val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if (!${genInputData.isNull}) {
if ($getKeyLength != $getValueLength) {
throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
}
int $dataLength = $getKeyLength;
$constructBuilder

int $loopIndex = 0;
while ($loopIndex < $dataLength) {
$keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
$valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
$valueLoopNullCheck

${genKeyFunction.code}
${genValueFunction.code}

$appendToBuilder

$loopIndex += 1;
}

$getBuilderResult
}
"""
ev.copy(code = code, isNull = genInputData.isNull)
}
}

Contributor, on the key/value length check: we don't need a keyLength and a valueLength, just a mapLength, which can be calculated with MapData.numElements.

Contributor, on $valueLoopNullCheck: we can also inline this. The principle is that we should inline simple code like this as much as possible, so that when you look at the code block it's clearer what's going on.
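On the mapLength point, a sketch of the simplification (an assumption, not the merged code): a MapData has a single element count, so one numElements() call replaces both lengths and the consistency check:

    // Hypothetical: derive one length from MapData.numElements().
    val getMapLength = s"${genInputData.value}.numElements()"
    // ...and in the template, simply: int $dataLength = $getMapLength;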

object ExternalMapToCatalyst {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
}

test("serialize and deserialize arbitrary map types") {
val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
assert(mapSerializer.dataType.head.dataType ==
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mapDeserializer = deserializerFor[Map[Int, Int]]
assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))

import scala.collection.immutable.HashMap
val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
assert(hashMapSerializer.dataType.head.dataType ==
MapType(IntegerType, IntegerType, valueContainsNull = false))
val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))

import scala.collection.mutable.{LinkedHashMap => LHMap}
val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
assert(linkedHashMapSerializer.dataType.head.dataType ==
MapType(LongType, StringType, valueContainsNull = true))
val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

import scala.collection.Map
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

@@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 2.2.0 */
implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()

// Maps
/** @since 2.3.0 */
implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()
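For reference, a quick sketch of what newMapEncoder enables, assuming a SparkSession named spark with its implicits imported:

    import spark.implicits._

    val ds = Seq(Map("a" -> 1), Map("b" -> 2)).toDS()
    ds.collect()   // Array(Map(a -> 1), Map(b -> 2))

    // The concrete collection type is preserved through the round trip:
    val lhm = Seq(scala.collection.mutable.LinkedHashMap(1 -> "x")).toDS()
    lhm.collect()  // Array(Map(1 -> x)), as a LinkedHashMap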

// Arrays

/** @since 1.6.1 */
sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import scala.collection.immutable.Queue
import scala.collection.mutable.{LinkedHashMap => LHMap}
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.test.SharedSQLContext
@@ -30,8 +31,14 @@ case class ListClass(l: List[Int])

case class QueueClass(q: Queue[Int])

case class MapClass(m: Map[Int, Int])

case class LHMapClass(m: LHMap[Int, Int])

case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)

case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)

package object packageobject {
case class PackageClass(value: Int)
}
@@ -258,6 +265,80 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}

test("arbitrary maps") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this suite is DatasetPrimitiveSuite, we should move the list/seq/map tests to a new suite DatasetComplexTypeSuite

checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2))
checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong))
checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble))
checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat))
checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte))
checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort))
checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false))
checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2"))
checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2)))
checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2)))
checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong))

checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2))
checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong))
checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble))
checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat))
checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte))
checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort))
checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false))
checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2"))
checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2)))
checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2)))
checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
}

Contributor: can we add some nested map cases? e.g. Map(1 -> LHMap(2 -> 3))

Author: Added as a separate test case (same as the sequences tests).
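The separate nested-map test the author refers to is not shown in this hunk; a plausible sketch, mirroring the nested-sequences test below (hypothetical cases):

    test("nested maps") {
      checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3)))
      checkDataset(Seq(LHMap(Map("test" -> 1) -> 2)).toDS(), LHMap(Map("test" -> 1) -> 2))
    }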

ignore("SPARK-19104: map and product combinations") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why ignore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added these tests for issue SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0 as I thought I could fix it as part of this PR. However, I found out that it was a more complicated issue than I anticipated so I left the tests there and ignored them. I can remove them.

// Case classes
checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3))
checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3))))
checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3))
checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))

checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2)))
checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
Map(1 -> LHMapClass(LHMap(2 -> 3))))
checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
Map(LHMapClass(LHMap(1 -> 2)) -> 3))
checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
LHMap(1 -> LHMapClass(LHMap(2 -> 3))))
checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
LHMap(LHMapClass(LHMap(1 -> 2)) -> 3))
checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))

val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4)))
checkDataset(Seq(complex).toDS(), complex)
checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex))
checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5))
checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex))
checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex))
checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5))
checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex))

// Tuples
checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4))
checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4))
checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4))
checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4))
checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(),
LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2"))))

// Complex
checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(),
LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
}

test("nested sequences") {
checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))