
Commit 65a0d43

[SPARK-28178][SQL] DataSourceV2: DataFrameWriter.insertInto
1 parent 3de4e1b commit 65a0d43

2 files changed, +159 -2 lines changed

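For orientation, a minimal usage sketch (not part of the commit) of what the change enables: DataFrameWriter.insertInto can now target a table addressed by a multi-part, catalog-qualified name, and SaveMode.Overwrite is translated into either a static or a dynamic overwrite plan. The catalog "testcat" and table "testcat.ns1.ns2.tbl" are borrowed from the test suite added below; all other names here are illustrative assumptions.

// Sketch only. Assumes an active `spark` session, `import spark.implicits._`,
// a v2 catalog named "testcat" registered the way the test suite below does it,
// and that the target tables already exist.
val data = Seq((1L, "a"), (2L, "b")).toDF("id", "data")

// Catalog-qualified, multi-part name: parsed with parseMultipartIdentifier and
// resolved against the "testcat" catalog; the default SaveMode.Append becomes an
// AppendData plan.
data.write.insertInto("testcat.ns1.ns2.tbl")

// Overwrite becomes OverwriteByExpression(..., true) or OverwritePartitionsDynamic,
// depending on spark.sql.sources.partitionOverwriteMode and the table's partitioning.
data.write.mode("overwrite").insertInto("testcat.ns1.ns2.tbl")

// Plain session-catalog names still take the pre-existing TableIdentifier path.
data.write.insertInto("some_session_table")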

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 52 additions & 2 deletions
@@ -22,16 +22,19 @@ import java.util.{Locale, Properties, UUID}
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
+import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.TableCapability._
 import org.apache.spark.sql.types.StructType
@@ -356,7 +359,54 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * @since 1.4.0
    */
   def insertInto(tableName: String): Unit = {
-    insertInto(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName))
+    import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
+
+    df.sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
+      case CatalogObjectIdentifier(Some(catalog), ident) =>
+        insertInto(catalog, ident)
+      case AsTableIdentifier(tableIdentifier) =>
+        insertInto(tableIdentifier)
+    }
+  }
+
+  private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = {
+    import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
+    assertNotBucketed("insertInto")
+
+    if (partitioningColumns.isDefined) {
+      throw new AnalysisException(
+        "insertInto() can't be used together with partitionBy(). " +
+          "Partition columns have already been defined for the table. " +
+          "It is not necessary to use partitionBy()."
+      )
+    }
+
+    val table = DataSourceV2Relation.create(catalog.asTableCatalog.loadTable(ident))
+
+    val command = modeForDSV2 match {
+      case SaveMode.Append =>
+        AppendData.byName(table, df.logicalPlan)
+
+      case SaveMode.Overwrite =>
+        val conf = df.sparkSession.sessionState.conf
+        val dynamicPartitionOverwrite = table.table.partitioning.size > 0 &&
+          conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
+
+        if (dynamicPartitionOverwrite) {
+          OverwritePartitionsDynamic.byName(table, df.logicalPlan)
+        } else {
+          OverwriteByExpression.byName(table, df.logicalPlan, Literal(true))
+        }
+
+      case other =>
+        throw new AnalysisException(s"insertInto does not support $other mode, " +
+          s"please use Append or Overwrite mode instead.")
+    }
+
+    runCommand(df.sparkSession, "insertInto") {
+      command
+    }
   }
 
   private def insertInto(tableIdent: TableIdentifier): Unit = {
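The SaveMode.Overwrite branch above is the subtle part of the change: whether untouched partitions survive depends on spark.sql.sources.partitionOverwriteMode and on whether the target table is partitioned. A hedged sketch of the user-facing difference, mirroring the two overwrite tests in the new suite below (table name and data are illustrative assumptions):

// Assumes: an active `spark` session, `import spark.implicits._`, and a table
// testcat.ns1.ns2.tbl partitioned by `id` that already holds partitions id = 2 and id = 4.
val newData = Seq((2L, "replacement")).toDF("id", "data")

// Static mode (the default): OverwriteByExpression.byName(..., Literal(true)),
// i.e. the whole table is replaced, so only id = 2 remains afterwards.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "static")
newData.write.mode("overwrite").insertInto("testcat.ns1.ns2.tbl")

// Dynamic mode on a partitioned table: OverwritePartitionsDynamic.byName(...),
// i.e. only the id = 2 partition is rewritten and the untouched id = 4 partition
// is kept, as checked by the last test in the suite below.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
newData.write.mode("overwrite").insertInto("testcat.ns1.ns2.tbl")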
sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataSourceV2DataFrameSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
+  import testImplicits._
+
+  before {
+    spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
+    spark.conf.set("spark.sql.catalog.testcat2", classOf[TestInMemoryTableCatalog].getName)
+
+    val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
+    df.createOrReplaceTempView("source")
+    val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data")
+    df2.createOrReplaceTempView("source2")
+  }
+
+  after {
+    spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables()
+    spark.sql("DROP VIEW source")
+    spark.sql("DROP VIEW source2")
+  }
+
+  test("insertInto: append") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
+      spark.table("source").select("id", "data").write.insertInto(t1)
+      checkAnswer(spark.table(t1), spark.table("source"))
+    }
+  }
+
+  test("insertInto: append - across catalog") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    val t2 = "testcat2.db.tbl"
+    withTable(t1, t2) {
+      sql(s"CREATE TABLE $t1 USING foo AS TABLE source")
+      sql(s"CREATE TABLE $t2 (id bigint, data string) USING foo")
+      spark.table(t1).write.insertInto(t2)
+      checkAnswer(spark.table(t2), spark.table("source"))
+    }
+  }
+
+  test("insertInto: append partitioned table - dynamic clause") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
+      spark.table("source").write.insertInto(t1)
+      checkAnswer(spark.table(t1), spark.table("source"))
+    }
+  }
+
+  test("insertInto: overwrite non-partitioned table") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 USING foo AS TABLE source")
+      spark.table("source2").write.mode("overwrite").insertInto(t1)
+      checkAnswer(spark.table(t1), spark.table("source2"))
+    }
+  }
+
+  test("insertInto: overwrite - dynamic clause - static mode") {
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) {
+      val t1 = "testcat.ns1.ns2.tbl"
+      withTable(t1) {
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
+        Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1)
+        spark.table("source").write.mode("overwrite").insertInto(t1)
+        checkAnswer(spark.table(t1), spark.table("source"))
+      }
+    }
+  }
+
+  test("insertInto: overwrite - dynamic clause - dynamic mode") {
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      val t1 = "testcat.ns1.ns2.tbl"
+      withTable(t1) {
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
+        Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1)
+        spark.table("source").write.mode("overwrite").insertInto(t1)
+        checkAnswer(spark.table(t1),
+          spark.table("source").union(sql("SELECT 4L, 'keep'")))
+      }
+    }
+  }
+}
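The new suite does not exercise the `case other` branch of the catalog-aware insertInto, which rejects save modes other than Append and Overwrite. A hypothetical test along the lines below (not part of the commit; names follow the suite above) would pin that behavior down:

  test("insertInto: unsupported save mode") {
    val t1 = "testcat.ns1.ns2.tbl"
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
      // ErrorIfExists (and likewise Ignore) falls into the `case other` branch
      // of the new private insertInto(catalog, ident) and fails analysis.
      val e = intercept[org.apache.spark.sql.AnalysisException] {
        spark.table("source").write.mode("error").insertInto(t1)
      }
      assert(e.getMessage.contains("insertInto does not support"))
    }
  }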
