sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))

- StaticInvoke(
+ val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some companions declare canBuildFrom with an implicit ClassTag parameter
// (e.g. scala.collection.mutable.WrappedArray); probe for such an overload
// reflectively and, if present, supply a ClassTag argument
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
}

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
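For intuition, here is roughly what the deserializer expression built above evaluates to at runtime. This is a hedged sketch against the Scala 2.11/2.12 collections API (`to`/`CanBuildFrom`, which no longer exists in 2.13); `decoded` is a stand-in for the element array produced by the surrounding deserializer, not a variable in this code:

```scala
import scala.collection.immutable.Queue
import scala.collection.mutable.WrappedArray

// Stand-in for the Array[Any] produced by the deserializer's element loop.
val decoded: Array[Any] = Array(1, 2, 3) // boxed to Object[] at runtime

// StaticInvoke(WrappedArray, "make", array): wrap the array without copying.
val wrapped: WrappedArray[Any] = WrappedArray.make(decoded)

// If the requested type already accepts WrappedArray, stop here. Otherwise
// Invoke(wrappedArray, "to", ..., companion.canBuildFrom) converts it:
val queue: Queue[Any] = wrapped.to[Queue] // Queue.canBuildFrom, no ClassTag

// Companions like WrappedArray declare canBuildFrom(implicit m: ClassTag[T]),
// which is why cbfParams above probes for a ClassTag-taking overload:
val back: WrappedArray[Any] = wrapped.to[WrappedArray] // needs ClassTag[Any]
```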
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}

test("SPARK 16792: Get correct deserializer for List[_]") {
val listDeserializer = deserializerFor[List[Int]]
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
}

test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))

import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala (24 additions, 13 deletions)

@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
- abstract class SQLImplicits {
+ abstract class SQLImplicits extends LowPrioritySQLImplicits {

protected def _sqlContext: SQLContext

@@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}

- /** @since 1.6.0 */
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

// Primitives

/** @since 1.6.0 */
@@ -100,31 +97,32 @@ abstract class SQLImplicits {
// Seqs

/** @since 1.6.1 */
- implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+ implicit def newIntSeqEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+ implicit def newLongSeqEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+ implicit def newDoubleSeqEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+ implicit def newFloatSeqEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+ implicit def newByteSeqEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+ implicit def newShortSeqEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+ implicit def newBooleanSeqEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+ implicit def newStringSeqEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+ implicit def newProductSeqEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
+   ExpressionEncoder()

// Arrays

@@ -180,3 +178,16 @@ abstract class SQLImplicits {
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

}

/**
* Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
* Conflicting implicits are placed here to disambiguate resolution.
*
* Reasons for including specific implicits:
* newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
*/
trait LowPrioritySQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

}
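Why the trait split matters: `List[Int]` extends both `Seq[Int]` and `Product`, so `newIntSeqEncoder` and `newProductEncoder` are both applicable, and at equal priority the compiler would reject the call as ambiguous. Implicits declared in a subclass outrank implicits inherited from a parent, so moving `newProductEncoder` into the parent trait breaks the tie in favor of the sequence encoders. A minimal sketch, assuming an active SparkSession named `spark`:

```scala
import spark.implicits._ // brings SQLImplicits (and its parent trait) into scope

// List[Int] is both a Seq[Int] and a Product. SQLImplicits' own members
// outrank the inherited LowPrioritySQLImplicits members, so the compiler
// picks newIntSeqEncoder[List[Int]] here instead of reporting an ambiguity:
val lists = Seq(List(1, 2, 3), List(4, 5)).toDS()

// Ordinary case classes are Products but not Seqs, so newProductEncoder
// from the low-priority trait still applies with no competition:
case class Point(x: Int, y: Int)
val points = Seq(Point(1, 2), Point(3, 4)).toDS()
```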
sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala

@@ -130,6 +130,30 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}

test("arbitrary sequences") {

[Review comment — Contributor]
Let's also test nested sequences, e.g. List(Queue(1)), and sequences inside product, e.g. List(1) -> Queue(1).

[Reply — @michalsenkyr (Author), Jan 3, 2017]
I added some sequence-product combination tests.
Nested sequences were never supported (tried on master and 2.0.2). That would probably be worthy of another ticket.

import scala.collection.immutable.Queue
checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))

import scala.collection.mutable.ArrayBuffer
checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}
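For reference, the "sequence-product combination tests" mentioned in the reply above would look roughly like the following in this suite's style. These cases are a hypothetical illustration, not a quote of the follow-up commit:

```scala
import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer

// Hypothetical sequence-inside-product cases (Tuple2 is a Product):
checkDataset(Seq(Queue(1) -> "a").toDS(), Queue(1) -> "a")
checkDataset(Seq((List(1), ArrayBuffer("b"))).toDS(), (List(1), ArrayBuffer("b")))
```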

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))