@@ -17,20 +17,9 @@

package org.apache.spark.sql.kafka010

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.time.SpanSugar._
import scala.collection.mutable
import scala.util.Random

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.streaming.Trigger

// Run tests in KafkaSourceSuiteBase in continuous execution mode.
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
@@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
}.exists { r =>
// Ensure the new topic is present and the old topic is gone.
r.knownPartitions.exists(_.topic == topic2)
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.Trigger
@@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest {
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
@@ -34,7 +34,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, ForeachWriter}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.functions.{count, window}
@@ -117,7 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
} ++ (query.get.lastExecution match {
case null => Seq()
case e => e.logical.collect {
case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
}
})
if (sources.isEmpty) {
@@ -368,8 +368,8 @@ case class InsertIntoTable(
overwrite: Boolean,
ifPartitionNotExists: Boolean)
extends LogicalPlan {
// IF NOT EXISTS is only valid in INSERT OVERWRITE
assert(overwrite || !ifPartitionNotExists)
// overwrite=false with ifPartitionNotExists=true is now a valid combination; it is used to pass mode=Ignore

// IF NOT EXISTS is only valid in static partitions
assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists)
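Dropping the first assertion makes (overwrite = false, ifPartitionNotExists = true), which was previously rejected, a legal encoding on InsertIntoTable. A minimal sketch of the two-flag encoding this implies, assuming helper names that match the DataSourceV2Utils call sites later in this diff; their actual bodies are not shown here, and the ErrorIfExists case is a guess:

```scala
import org.apache.spark.sql.SaveMode

// Sketch only: inferred from the call sites in DataFrameWriter and DataSourceAnalysis
// elsewhere in this diff; the real helpers on DataSourceV2Utils are not shown here.
object SaveModeFlagsSketch {
  // SaveMode -> (overwrite, ifPartitionNotExists)
  def overwriteAndIfNotExists(mode: SaveMode): (Boolean, Boolean) = mode match {
    case SaveMode.Append        => (false, false)
    case SaveMode.Overwrite     => (true, false)
    case SaveMode.Ignore        => (false, true)   // the combination the removed assert forbade
    case SaveMode.ErrorIfExists => (true, true)    // assumption; not determinable from this diff
  }

  // (overwrite, ifPartitionNotExists) -> SaveMode, the inverse used when planning the write
  def saveMode(overwrite: Boolean, ifNotExists: Boolean): SaveMode =
    (overwrite, ifNotExists) match {
      case (false, false) => SaveMode.Append
      case (true, false)  => SaveMode.Overwrite
      case (false, true)  => SaveMode.Ignore
      case (true, true)   => SaveMode.ErrorIfExists  // assumption
    }
}
```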

41 changes: 9 additions & 32 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

@@ -185,39 +185,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val ds = cls.newInstance()
val options = new DataSourceOptions((extraOptions ++
DataSourceV2Utils.extractSessionConfigs(
ds = ds.asInstanceOf[DataSourceV2],
conf = sparkSession.sessionState.conf)).asJava)

// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
// the dataframe as a v1 source.
val reader = (ds, userSpecifiedSchema) match {
case (ds: ReadSupportWithSchema, Some(schema)) =>
ds.createReader(schema, options)

case (ds: ReadSupport, None) =>
ds.createReader(options)

case (ds: ReadSupportWithSchema, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $ds.")

case (ds: ReadSupport, Some(schema)) =>
val reader = ds.createReader(options)
if (reader.readSchema() != schema) {
throw new AnalysisException(s"$ds does not allow user-specified schemas.")
}
reader

case _ => null // fall back to v1
}
val ds = cls.newInstance().asInstanceOf[DataSourceV2]
if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) {
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = ds, conf = sparkSession.sessionState.conf)
Dataset.ofRows(sparkSession, DataSourceV2Relation(
ds, extraOptions.toMap ++ sessionOptions, path = extraOptions.get("path"),
userSchema = userSpecifiedSchema))

if (reader == null) {
loadV1Source(paths: _*)
} else {
Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
loadV1Source(paths: _*)
}
} else {
loadV1Source(paths: _*)
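Nothing changes for callers; what changes is the logical plan that load() produces for a v2 source. A usage sketch, assuming a hypothetical source class name and a locally built SparkSession:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("v2-read-sketch")
  .getOrCreate()

// The format class is hypothetical; any DataSourceV2 that implements ReadSupport or
// ReadSupportWithSchema now resolves to a DataSourceV2Relation carrying the source
// instance and its options, rather than an eagerly created DataSourceReader.
val df = spark.read
  .format("com.example.SimpleDataSourceV2")
  .option("path", "/tmp/example")
  .load()

df.explain(true)  // the analyzed plan should contain a DataSourceV2Relation node
```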
32 changes: 13 additions & 19 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -17,8 +17,7 @@

package org.apache.spark.sql

import java.text.SimpleDateFormat
import java.util.{Date, Locale, Properties, UUID}
import java.util.{Locale, Properties}

import scala.collection.JavaConverters._

@@ -30,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoT
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.types.StructType
@@ -240,22 +238,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val ds = cls.newInstance()
val ds = cls.newInstance().asInstanceOf[DataSourceV2]
ds match {
case ws: WriteSupport =>
val options = new DataSourceOptions((extraOptions ++
DataSourceV2Utils.extractSessionConfigs(
ds = ds.asInstanceOf[DataSourceV2],
conf = df.sparkSession.sessionState.conf)).asJava)
// Using a timestamp and a random UUID to distinguish different writing jobs. This is good
// enough as there won't be tons of writing jobs created at the same second.
val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
.format(new Date()) + "-" + UUID.randomUUID()
val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
if (writer.isPresent) {
runCommand(df.sparkSession, "save") {
WriteToDataSourceV2(writer.get(), df.logicalPlan)
}
case _: WriteSupport =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = ds, conf = df.sparkSession.sessionState.conf)
val relation = DataSourceV2Relation(
ds, extraOptions.toMap ++ sessionOptions, path = extraOptions.get("path"))

val (overwrite, ifNotExists) = DataSourceV2Utils.overwriteAndIfNotExists(mode)

runCommand(df.sparkSession, "save") {
InsertIntoTable(relation, Map.empty, df.logicalPlan, overwrite, ifNotExists)
}

// Streaming also uses the data source V2 API. So it may be that the data source implements
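The write path now mirrors the read path: instead of creating a DataSourceWriter eagerly, save() builds an InsertIntoTable over a DataSourceV2Relation and defers the physical write to the planner (the DataSourceAnalysis case added further down in this diff). A caller-side sketch with a hypothetical v2 sink:

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("v2-write-sketch")
  .getOrCreate()
val df = spark.range(10).toDF("id")

// The sink class name is hypothetical; it stands for any DataSourceV2 with WriteSupport.
df.write
  .format("com.example.SimpleWritableDataSourceV2")
  .mode(SaveMode.Ignore)                 // carried as the two boolean flags sketched earlier
  .option("path", "/tmp/example-out")
  .save()
```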
@@ -46,6 +46,12 @@ trait RunnableCommand extends Command {
def run(sparkSession: SparkSession): Seq[Row]
}

case class NoopCommand() extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
Seq.empty[Row]
}
}

/**
* A physical operator that executes the run method of a `RunnableCommand` and
* saves the result to prevent multiple executions.
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -145,6 +146,14 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
parts, query, overwrite, false) if parts.isEmpty =>
InsertIntoDataSourceCommand(l, query, overwrite)

case InsertIntoTable(rel: DataSourceV2Relation, _, query, overwrite, ifNotExists) =>
val writer = rel.writer(query.schema, DataSourceV2Utils.saveMode(overwrite, ifNotExists))
if (writer.isDefined) {
WriteToDataSourceV2(writer.get, query)
} else {
NoopCommand()
}
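The new case relies on rel.writer(schema, mode), which is not part of this diff. A rough sketch of what such a helper plausibly does, written as a standalone function under assumed names; only WriteSupport.createWriter and DataSourceOptions come from the existing DataSourceV2 API, everything else is an assumption:

```scala
import java.util.Optional

import scala.collection.JavaConverters._

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, WriteSupport}
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
import org.apache.spark.sql.types.StructType

// Sketch only: the real writer(...) lives on DataSourceV2Relation and is not shown in
// this diff; names and the job-id scheme here are assumptions.
object WriterSketch {
  def createWriter(
      source: DataSourceV2,
      options: Map[String, String],
      schema: StructType,
      mode: SaveMode): Option[DataSourceWriter] = source match {
    case ws: WriteSupport =>
      val jobId = java.util.UUID.randomUUID().toString  // assumed job id scheme
      val maybe: Optional[DataSourceWriter] =
        ws.createWriter(jobId, schema, mode, new DataSourceOptions(options.asJava))
      if (maybe.isPresent) Some(maybe.get) else None
    case _ =>
      None  // no writer; the rule above then falls back to NoopCommand()
  }
}
```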

case InsertIntoDir(_, storage, provider, query, overwrite)
if provider.isDefined && provider.get.toLowerCase(Locale.ROOT) != DDLUtils.HIVE_PROVIDER =>

@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression,
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.{AtomicType, StructType}
@@ -392,6 +393,9 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit
case LogicalRelation(_: InsertableRelation, _, catalogTable, _) =>
val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown")
preprocess(i, tblName, Nil)
case relation: DataSourceV2Relation =>
val tableName = relation.table.map(_.toString).orElse(relation.path).getOrElse("unknown")
preprocess(i, tableName, Nil)
case _ => i
}
}