Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class Analyzer(
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
ResolveInlineTables(conf) ::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to fix this bug by reordering the analyzer rules?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that before, but the resolution rules can add new timezone-aware expressions, so we would still need a rule that resolves timezone-aware expressions after the resolution rules have run.

TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}

/**
* An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
*/
object ResolveInlineTables extends Rule[LogicalPlan] {
case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case table: UnresolvedInlineTable if table.expressionsResolved =>
validateInputDimension(table)
Expand Down Expand Up @@ -95,10 +95,14 @@ object ResolveInlineTables extends Rule[LogicalPlan] {
InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
val targetType = fields(ci).dataType
try {
if (e.dataType.sameType(targetType)) {
e.eval()
val castedExpr = if (e.dataType.sameType(targetType)) {
e
} else {
Cast(e, targetType).eval()
Cast(e, targetType)
}
castedExpr match {
case te: TimeZoneAwareExpression => te.withTimeZone(conf.sessionLocalTimeZone).eval()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we traverse the entire expression tree to check for time-zone-aware expressions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or did ResolveTimeZone already process the input expression? In that case just add the timezone to the cast.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should traverse the whole tree. ResolveTimeZone is performed only after all resolution rules finish, so it has not processed the input expression yet.

case _ => castedExpr.eval()
}
} catch {
case NonFatal(ex) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,82 +20,92 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}

/**
* Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in
* end-to-end tests (in sql/core module) for verifying the correct error messages are shown
* in negative cases.
*/
class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

private def lit(v: Any): Literal = Literal(v)

test("validate inputs are foldable") {
ResolveInlineTables.validateInputEvaluable(
ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

// nondeterministic (rand) should not work
intercept[AnalysisException] {
ResolveInlineTables.validateInputEvaluable(
ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
}

// aggregate should not work
intercept[AnalysisException] {
ResolveInlineTables.validateInputEvaluable(
ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
}

// unresolved attribute should not work
intercept[AnalysisException] {
ResolveInlineTables.validateInputEvaluable(
ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
}
}

test("validate input dimensions") {
ResolveInlineTables.validateInputDimension(
ResolveInlineTables(conf).validateInputDimension(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

// num alias != data dimension
intercept[AnalysisException] {
ResolveInlineTables.validateInputDimension(
ResolveInlineTables(conf).validateInputDimension(
UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
}

// num alias == data dimension, but data themselves are inconsistent
intercept[AnalysisException] {
ResolveInlineTables.validateInputDimension(
ResolveInlineTables(conf).validateInputDimension(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
}
}

test("do not fire the rule if not all expressions are resolved") {
val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
assert(ResolveInlineTables(table) == table)
assert(ResolveInlineTables(conf)(table) == table)
}

test("convert") {
val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
val converted = ResolveInlineTables.convert(table)
val converted = ResolveInlineTables(conf).convert(table)

assert(converted.output.map(_.dataType) == Seq(LongType))
assert(converted.data.size == 2)
assert(converted.data(0).getLong(0) == 1L)
assert(converted.data(1).getLong(0) == 2L)
}

// Checks that convert() on an inline table whose single value is a string-to-timestamp Cast
// yields the same value as evaluating that Cast with the session-local time zone from conf.
test("convert TimeZoneAwareExpression") {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
val converted = ResolveInlineTables(conf).convert(table)
// Expected value: the same cast, evaluated directly after applying conf.sessionLocalTimeZone.
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
// The output column must be a timestamp and the table must contain exactly one row.
assert(converted.output.map(_.dataType) == Seq(TimestampType))
assert(converted.data.size == 1)
assert(converted.data(0).getLong(0) == correct)
}

test("nullability inference in convert") {
val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
val converted1 = ResolveInlineTables.convert(table1)
val converted1 = ResolveInlineTables(conf).convert(table1)
assert(!converted1.schema.fields(0).nullable)

val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
val converted2 = ResolveInlineTables.convert(table2)
val converted2 = ResolveInlineTables(conf).convert(table2)
assert(converted2.schema.fields(0).nullable)
}
}
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b);
-- decimal and double coercion
select * from values ("one", 2.0), ("two", 3.0D) as data(a, b);

-- string to timestamp
select * from values timestamp('1991-12-06 00:00:00.0') as data(a);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put this at the end of the file? That reduces the size of the diff.


-- error reporting: nondeterministic function rand
select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b);

Expand Down
36 changes: 22 additions & 14 deletions sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 16
-- Number of queries: 17


-- !query 0
Expand Down Expand Up @@ -92,54 +92,62 @@ two 3.0


-- !query 10
select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b)
select * from values timestamp('1991-12-06 00:00:00.0') as data(a)
-- !query 10 schema
struct<>
struct<a:timestamp>
-- !query 10 output
org.apache.spark.sql.AnalysisException
cannot evaluate expression rand(5) in inline table definition; line 1 pos 29
1991-12-06 00:00:00


-- !query 11
select * from values ("one", 2.0), ("two") as data(a, b)
select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b)
-- !query 11 schema
struct<>
-- !query 11 output
org.apache.spark.sql.AnalysisException
expected 2 columns but found 1 columns in row 1; line 1 pos 14
cannot evaluate expression rand(5) in inline table definition; line 1 pos 29


-- !query 12
select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b)
select * from values ("one", 2.0), ("two") as data(a, b)
-- !query 12 schema
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
incompatible types found in column b for inline table; line 1 pos 14
expected 2 columns but found 1 columns in row 1; line 1 pos 14


-- !query 13
select * from values ("one"), ("two") as data(a, b)
select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b)
-- !query 13 schema
struct<>
-- !query 13 output
org.apache.spark.sql.AnalysisException
expected 2 columns but found 1 columns in row 0; line 1 pos 14
incompatible types found in column b for inline table; line 1 pos 14


-- !query 14
select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b)
select * from values ("one"), ("two") as data(a, b)
-- !query 14 schema
struct<>
-- !query 14 output
org.apache.spark.sql.AnalysisException
Undefined function: 'random_not_exist_func'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 29
expected 2 columns but found 1 columns in row 0; line 1 pos 14


-- !query 15
select * from values ("one", count(1)), ("two", 2) as data(a, b)
select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b)
-- !query 15 schema
struct<>
-- !query 15 output
org.apache.spark.sql.AnalysisException
Undefined function: 'random_not_exist_func'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 29


-- !query 16
select * from values ("one", count(1)), ("two", 2) as data(a, b)
-- !query 16 schema
struct<>
-- !query 16 output
org.apache.spark.sql.AnalysisException
cannot evaluate expression count(1) in inline table definition; line 1 pos 29