Skip to content

Commit 310454b

Browse files
DonnyZone
authored and committed
[SPARK-21739][SQL] Cast expression should initialize timezoneId when it is called statically to convert something into TimestampType
## What changes were proposed in this pull request? https://issues.apache.org/jira/projects/SPARK/issues/SPARK-21739 This issue is caused by introducing TimeZoneAwareExpression. When the **Cast** expression converts something into TimestampType, it should be resolved with setting `timezoneId`. In general, it is resolved in the LogicalPlan phase. However, there are still some places that use the Cast expression statically to convert datatypes without setting `timezoneId`. In such cases, `NoSuchElementException: None.get` will be thrown for TimestampType. This PR is proposed to fix the issue. We have checked the whole project and found two such usages (i.e., in `TableReader` and `HiveTableScanExec`). ## How was this patch tested? unit test Author: donnyzone <wellfengzhu@gmail.com> Closes #18960 from DonnyZone/spark-21739.
1 parent 2caaed9 commit 310454b

3 files changed

Lines changed: 29 additions & 4 deletions

File tree

sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging
3939
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
4040
import org.apache.spark.sql.SparkSession
4141
import org.apache.spark.sql.catalyst.InternalRow
42+
import org.apache.spark.sql.catalyst.analysis.CastSupport
4243
import org.apache.spark.sql.catalyst.expressions._
4344
import org.apache.spark.sql.catalyst.util.DateTimeUtils
45+
import org.apache.spark.sql.internal.SQLConf
4446
import org.apache.spark.unsafe.types.UTF8String
4547
import org.apache.spark.util.{SerializableConfiguration, Utils}
4648

@@ -65,7 +67,7 @@ class HadoopTableReader(
6567
@transient private val tableDesc: TableDesc,
6668
@transient private val sparkSession: SparkSession,
6769
hadoopConf: Configuration)
68-
extends TableReader with Logging {
70+
extends TableReader with CastSupport with Logging {
6971

7072
// Hadoop honors "mapreduce.job.maps" as hint,
7173
// but will ignore when mapreduce.jobtracker.address is "local".
@@ -86,6 +88,8 @@ class HadoopTableReader(
8688
private val _broadcastedHadoopConf =
8789
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
8890

91+
override def conf: SQLConf = sparkSession.sessionState.conf
92+
8993
override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
9094
makeRDDForTable(
9195
hiveTable,
@@ -227,7 +231,7 @@ class HadoopTableReader(
227231
def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = {
228232
partitionKeyAttrs.foreach { case (attr, ordinal) =>
229233
val partOrdinal = partitionKeys.indexOf(attr)
230-
row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
234+
row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
231235
}
232236
}
233237

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
3030
import org.apache.spark.rdd.RDD
3131
import org.apache.spark.sql.SparkSession
3232
import org.apache.spark.sql.catalyst.InternalRow
33+
import org.apache.spark.sql.catalyst.analysis.CastSupport
3334
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
3435
import org.apache.spark.sql.catalyst.expressions._
3536
import org.apache.spark.sql.catalyst.plans.QueryPlan
3637
import org.apache.spark.sql.execution._
3738
import org.apache.spark.sql.execution.metric.SQLMetrics
3839
import org.apache.spark.sql.hive._
3940
import org.apache.spark.sql.hive.client.HiveClientImpl
41+
import org.apache.spark.sql.internal.SQLConf
4042
import org.apache.spark.sql.types.{BooleanType, DataType}
4143
import org.apache.spark.util.Utils
4244

@@ -53,11 +55,13 @@ case class HiveTableScanExec(
5355
relation: HiveTableRelation,
5456
partitionPruningPred: Seq[Expression])(
5557
@transient private val sparkSession: SparkSession)
56-
extends LeafExecNode {
58+
extends LeafExecNode with CastSupport {
5759

5860
require(partitionPruningPred.isEmpty || relation.isPartitioned,
5961
"Partition pruning predicates only supported for partitioned tables.")
6062

63+
override def conf: SQLConf = sparkSession.sessionState.conf
64+
6165
override lazy val metrics = Map(
6266
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
6367

@@ -104,7 +108,7 @@ case class HiveTableScanExec(
104108
hadoopConf)
105109

106110
private def castFromString(value: String, dataType: DataType) = {
107-
Cast(Literal(value), dataType).eval(null)
111+
cast(Literal(value), dataType).eval(null)
108112
}
109113

110114
private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {

sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.hive
1919

2020
import java.io.File
21+
import java.sql.Timestamp
2122

2223
import com.google.common.io.Files
2324
import org.apache.hadoop.fs.FileSystem
@@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl
6869
sql("DROP TABLE IF EXISTS createAndInsertTest")
6970
}
7071
}
72+
73+
test("SPARK-21739: Cast expression should initialize timezoneId") {
74+
withTable("table_with_timestamp_partition") {
75+
sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
76+
sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
77+
"PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")
78+
79+
// test for Cast expression in TableReader
80+
checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
81+
Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))
82+
83+
// test for Cast expression in HiveTableScanExec
84+
checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
85+
"WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
86+
}
87+
}
7188
}

0 commit comments

Comments (0)