PostgresIntegrationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Connection, Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.LocalDateTime
import java.util.Properties

import org.apache.spark.sql.Column
@@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
"c0 money)").executeUpdate()
conn.prepareStatement("INSERT INTO money_types VALUES " +
"('$1,000.00')").executeUpdate()

conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate()
conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES
|('2013-04-05 12:01:02'),
|('2013-04-05 18:01:02.123'),
|('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate()
}

test("Type mapping for various types") {
@@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(row(0).length === 1)
assert(row(0).getString(0) === "$1,000.00")
}

test("SPARK-43040: timestamp_ntz read test") {
val prop = new Properties
prop.setProperty("preferTimestampNTZ", "true")
val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop)
val row = df.collect()
assert(row.length === 3)
assert(row(0).length === 1)
assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2)))
assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000)))
assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000)))
}

test("SPARK-43040: timestamp_ntz roundtrip test") {
val prop = new Properties
prop.setProperty("preferTimestampNTZ", "true")

val sparkQuery = """
|select
| timestamp_ntz'2020-12-10 11:22:33' as col0
""".stripMargin

val df_expected = sqlContext.sql(sparkQuery)
df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)

val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
assert(df_actual.collect()(0) == df_expected.collect()(0))
}
}
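
For reference, a minimal sketch of the preferTimestampNTZ option these tests exercise, assuming a reachable PostgreSQL instance; the URL, database name, and credentials below are placeholders, not values from this PR.

import java.util.Properties
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val props = new Properties()
props.setProperty("user", "postgres")           // hypothetical credentials
props.setProperty("password", "secret")
props.setProperty("preferTimestampNTZ", "true") // map TIMESTAMP (without time zone) to TimestampNTZType

// With preferTimestampNTZ=true the column arrives as TimestampNTZType, and
// collect() yields java.time.LocalDateTime values, as asserted above.
val df = spark.read.jdbc("jdbc:postgresql://localhost:5432/testdb", "timestamp_ntz", props)
df.printSchema()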
JDBCRDD.scala
@@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD(
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
val rowsIterator =
JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)

CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
JdbcUtils.scala
@@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String
@@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper {
/**
* Convert a [[ResultSet]] into an iterator of Catalyst Rows.
*/
def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
def resultSetToRows(
resultSet: ResultSet,
schema: StructType): Iterator[Row] = {
resultSetToRows(resultSet, schema, NoopDialect)
}

def resultSetToRows(
resultSet: ResultSet,
schema: StructType,
dialect: JdbcDialect): Iterator[Row] = {
val inputMetrics =
Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics)
internalRows.map(fromRow)
}

private[spark] def resultSetToSparkInternalRows(
resultSet: ResultSet,
dialect: JdbcDialect,
schema: StructType,
inputMetrics: InputMetrics): Iterator[InternalRow] = {
new NextIterator[InternalRow] {
private[this] val rs = resultSet
private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))

override protected def close(): Unit = {
@@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper {
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[InternalRow]] correctly.
*/
private def makeGetters(schema: StructType): Array[JDBCValueGetter] = {
private def makeGetters(
dialect: JdbcDialect,
schema: StructType): Array[JDBCValueGetter] = {
val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata))
}

private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
private def makeGetter(
dt: DataType,
dialect: JdbcDialect,
metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: InternalRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
(rs: ResultSet, row: InternalRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t))
row.setLong(pos,
DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t)))
} else {
row.update(pos, null)
}
@@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper {

case TimestampNTZType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos))
stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros))
stmt.setTimestamp(pos + 1,
dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos)))

case DateType =>
if (conf.datetimeJava8ApiEnabled) {
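
Taken together, the getter and setter above now route TimestampNTZ conversions through the dialect. Below is a minimal sketch of the default round trip, using the internal DateTimeUtils helpers imported in this file; it is not part of the diff, and the exactness holds because both NoRebase conversions operate on epoch microseconds.

import java.sql.Timestamp
import java.time.LocalDateTime
import org.apache.spark.sql.catalyst.util.DateTimeUtils

val ldt = LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000)

// Write path (makeSetter, TimestampNTZType, default dialect):
// LocalDateTime -> epoch micros -> java.sql.Timestamp, no calendar rebase.
val ts: Timestamp =
  DateTimeUtils.toJavaTimestampNoRebase(DateTimeUtils.localDateTimeToMicros(ldt))

// Read path (makeGetter, TimestampNTZType, default dialect):
// java.sql.Timestamp -> epoch micros -> LocalDateTime.
val back: LocalDateTime =
  DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(ts))

assert(back == ldt) // the two NoRebase conversions are exact inverses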
JdbcDialect.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Date, Driver, Statement, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Instant, LocalDate, LocalDateTime}
import java.util

import scala.collection.mutable.ArrayBuilder
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase}
import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
@@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging {
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None

/**
* Converts a java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as the
* value stored in the remote database.
* JDBC dialects should override this function to provide implementations that suit their
* JDBC drivers.
* @param t a Timestamp returned from the JDBC driver's getTimestamp method.
* @return a LocalDateTime representing the same wall-clock time as the timestamp in the database.
*/
@Since("3.5.0")
def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t))
}

/**
* Converts a LocalDateTime representing a TimestampNTZ type to an
* instance of `java.sql.Timestamp`.
* @param ldt a LocalDateTime representing a TimestampNTZ value.
* @return a java.sql.Timestamp representing the same wall-clock time as the given LocalDateTime.
*/
@Since("3.5.0")
def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
val micros = localDateTimeToMicros(ldt)
toJavaTimestampNoRebase(micros)
}

/**
* Returns a factory for creating connections to the given JDBC URL.
* In general, creating a connection has nothing to do with JDBC partition id.
@@ -682,6 +708,6 @@ object JdbcDialects {
/**
* NOOP dialect object, always returning the neutral element.
*/
private object NoopDialect extends JdbcDialect {
object NoopDialect extends JdbcDialect {
override def canHandle(url : String): Boolean = true
}
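
A hedged sketch of how a third-party dialect could adopt the two new hooks; MyNtzDialect and the jdbc:mydb URL prefix are hypothetical, while JdbcDialects.registerDialect is the existing public registration entry point.

import java.sql.Timestamp
import java.time.LocalDateTime
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

object MyNtzDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Trust the driver's wall-clock reading rather than the epoch-based default.
  override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime =
    t.toLocalDateTime

  // Symmetric write-side conversion.
  override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp =
    Timestamp.valueOf(ldt)
}

JdbcDialects.registerDialect(MyNtzDialect)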
PostgresDialect.scala
@@ -17,7 +17,8 @@

package org.apache.spark.sql.jdbc

import java.sql.{Connection, SQLException, Types}
import java.sql.{Connection, SQLException, Timestamp, Types}
import java.time.LocalDateTime
import java.util
import java.util.Locale

@@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
case _ => None
}

override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
t.toLocalDateTime
}

Review comment (Contributor) on convertJavaTimestampToTimestampNTZ: After the change, refer to
https://github.com/apache/spark/pull/40678/files#r1162437868; we can update this with
DateTimeUtils.localDateTimeToMicros(t.toLocalDateTime) here.

override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
Timestamp.valueOf(ldt)
}

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", Types.VARCHAR))
case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
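
Finally, a small sketch of why the two Postgres overrides pair up: Timestamp.valueOf and Timestamp#toLocalDateTime are inverses in the JVM default time zone (DST gaps aside), so a timestamp-without-time-zone value keeps its wall-clock reading across a round trip.

import java.sql.Timestamp
import java.time.LocalDateTime

val ldt = LocalDateTime.of(2020, 12, 10, 11, 22, 33)
val ts = Timestamp.valueOf(ldt)    // write side: convertTimestampNTZToJavaTimestamp
assert(ts.toLocalDateTime == ldt)  // read side: convertJavaTimestampToTimestampNTZ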