diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala index 2d4d6e7c6d5ee..ea5fc05dd5ff3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala @@ -94,6 +94,17 @@ private[sql] final case class BucketTransform( override def toString: String = describe } +private[sql] object BucketTransform { + def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { + case NamedTransform("bucket", Seq( + Lit(value: Int, IntegerType), + Ref(seq: Seq[String]))) => + Some((value, FieldReference(seq))) + case _ => + None + } +} + private[sql] final case class ApplyTransform( name: String, args: Seq[Expression]) extends Transform { @@ -111,32 +122,104 @@ private[sql] final case class ApplyTransform( override def toString: String = describe } +/** + * Convenience extractor for any Literal. + */ +private object Lit { + def unapply[T](literal: Literal[T]): Some[(T, DataType)] = { + Some((literal.value, literal.dataType)) + } +} + +/** + * Convenience extractor for any NamedReference. + */ +private object Ref { + def unapply(named: NamedReference): Some[Seq[String]] = { + Some(named.fieldNames) + } +} + +/** + * Convenience extractor for any Transform. + */ +private object NamedTransform { + def unapply(transform: Transform): Some[(String, Seq[Expression])] = { + Some((transform.name, transform.arguments)) + } +} + private[sql] final case class IdentityTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "identity" override def describe: String = ref.describe } +private[sql] object IdentityTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("identity", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class YearsTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "years" } +private[sql] object YearsTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("years", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class MonthsTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "months" } +private[sql] object MonthsTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("months", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class DaysTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "days" } +private[sql] object DaysTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("days", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class HoursTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "hours" } +private[sql] object HoursTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("hours", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { override def describe: String = { if (dataType.isInstanceOf[StringType]) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala new file mode 100644 index 0000000000000..c0a5dada19dba --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.DataType + +class TransformExtractorSuite extends SparkFunSuite { + /** + * Creates a Literal using an anonymous class. + */ + private def lit[T](literal: T): Literal[T] = new Literal[T] { + override def value: T = literal + override def dataType: DataType = catalyst.expressions.Literal(literal).dataType + override def describe: String = literal.toString + } + + /** + * Creates a NamedReference using an anonymous class. + */ + private def ref(names: String*): NamedReference = new NamedReference { + override def fieldNames: Array[String] = names.toArray + override def describe: String = names.mkString(".") + } + + /** + * Creates a Transform using an anonymous class. + */ + private def transform(func: String, ref: NamedReference): Transform = new Transform { + override def name: String = func + override def references: Array[NamedReference] = Array(ref) + override def arguments: Array[Expression] = Array(ref) + override def describe: String = ref.describe + } + + test("Identity extractor") { + transform("identity", ref("a", "b")) match { + case IdentityTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match IdentityTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case IdentityTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Years extractor") { + transform("years", ref("a", "b")) match { + case YearsTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match YearsTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case YearsTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Months extractor") { + transform("months", ref("a", "b")) match { + case MonthsTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match MonthsTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case MonthsTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Days extractor") { + transform("days", ref("a", "b")) match { + case DaysTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match DaysTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case DaysTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Hours extractor") { + transform("hours", ref("a", "b")) match { + case HoursTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match HoursTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case HoursTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Bucket extractor") { + val col = ref("a", "b") + val bucketTransform = new Transform { + override def name: String = "bucket" + override def references: Array[NamedReference] = Array(col) + override def arguments: Array[Expression] = Array(lit(16), col) + override def describe: String = s"bucket(16, ${col.describe})" + } + + bucketTransform match { + case BucketTransform(numBuckets, FieldReference(seq)) => + assert(numBuckets === 16) + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match BucketTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case BucketTransform(_, _) => + fail("Matched unknown transform") + case _ => + // expected + } + } +}