diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
index facfc0d774e8..ad4fe743218f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -110,6 +110,26 @@ public interface TableCatalog extends CatalogPlugin {
 */
 Table loadTable(Identifier ident) throws NoSuchTableException;
+ /**
+ * Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data
+ * into this table later.
+ * <p>
+ * If the catalog supports views and contains a view for the identifier and not a table, this
+ * must throw {@link NoSuchTableException}.
+ *
+ * @param ident a table identifier
+ * @param writePrivileges the write privileges that the subsequent write will require on this table
+ * @return the table's metadata
+ * @throws NoSuchTableException If the table doesn't exist or is a view
+ *
+ * @since 3.5.3
+ */
+ default Table loadTable(
+ Identifier ident,
+ Set<TableWritePrivilege> writePrivileges) throws NoSuchTableException {
+ return loadTable(ident);
+ }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java
new file mode 100644
index 000000000000..ca2d4ba9e7b4
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+/**
+ * The table write privileges that will be provided when loading a table.
+ *
+ * @since 3.5.3
+ */
+public enum TableWritePrivilege {
+ /**
+ * The privilege for adding rows to the table.
+ */
+ INSERT,
+
+ /**
+ * The privilege for changing existing rows in the table.
+ */
+ UPDATE,
+
+ /**
+ * The privilege for deleting rows from the table.
+ */
+ DELETE
+}
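For context, a minimal sketch of how a catalog implementation could use the new overload to enforce access control. This is not part of the patch: `AclClient` and the `currentUser` parameter are assumed helpers, not Spark APIs.

```scala
import java.util.{Set => JSet}

import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog, TableWritePrivilege}

// Hypothetical ACL service, not a Spark API.
trait AclClient {
  def isAllowed(user: String, ident: Identifier, privilege: TableWritePrivilege): Boolean
}

abstract class AclCheckingCatalog(aclClient: AclClient, currentUser: String) extends TableCatalog {

  override def loadTable(
      ident: Identifier,
      writePrivileges: JSet[TableWritePrivilege]): Table = {
    // Reject the load early if any required write privilege is missing.
    writePrivileges.forEach { privilege =>
      if (!aclClient.isAllowed(currentUser, ident, privilege)) {
        throw new SecurityException(s"User $currentUser lacks $privilege on ${ident.name}")
      }
    }
    // Delegate the actual metadata load to the existing single-argument overload.
    loadTable(ident)
  }
}
```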
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 20546c5c5be3..92ab804f0a70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1331,8 +1331,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
cachedConnectRelation
}.getOrElse(cachedRelation)
}.orElse {
- val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
- val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
+ val writePrivilegesString =
+ Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
+ val table = CatalogV2Util.loadTable(
+ catalog, ident, finalTimeTravelSpec, writePrivilegesString)
+ val loaded = createRelation(
+ catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
loaded.map { loadedRelation =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eb718eec4256..60d979e9c7af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, Lo
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -134,10 +135,36 @@ case class UnresolvedRelation(
override def name: String = tableName
+ def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = {
+ if (privileges.nonEmpty) {
+ val newOptions = new java.util.HashMap[String, String]
+ newOptions.putAll(options)
+ newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(","))
+ copy(options = new CaseInsensitiveStringMap(newOptions))
+ } else {
+ this
+ }
+ }
+
+ def clearWritePrivileges: UnresolvedRelation = {
+ if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) {
+ val newOptions = new java.util.HashMap[String, String]
+ newOptions.putAll(options)
+ newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
+ copy(options = new CaseInsensitiveStringMap(newOptions))
+ } else {
+ this
+ }
+ }
+
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}
object UnresolvedRelation {
+ // An internal option of `UnresolvedRelation` to specify the required write privileges when
+ // writing data to this relation.
+ val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__"
+
def apply(
tableIdentifier: TableIdentifier,
extraOptions: CaseInsensitiveStringMap,
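To make the plumbing concrete, here is an illustrative round trip of the internal option using the methods added above (the table name is made up):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.connector.catalog.TableWritePrivilege

val rel = UnresolvedRelation(Seq("cat", "db", "t"))
  .requireWritePrivileges(Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE))

// The privileges travel as a comma-separated string in an internal option...
assert(rel.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) == "INSERT,DELETE")

// ...and are stripped again before the analyzer turns the relation into a v2 relation.
assert(!rel.clearWritePrivileges.options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
```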
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 038f15ee1103..b0922542c562 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors}
@@ -469,7 +469,7 @@ class AstBuilder extends DataTypeAstBuilder
= visitInsertIntoTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident, options),
+ createUnresolvedRelation(relationCtx, ident, options, Seq(TableWritePrivilege.INSERT)),
partition,
cols,
otherPlans.head,
@@ -482,7 +482,8 @@ class AstBuilder extends DataTypeAstBuilder
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident, options),
+ createUnresolvedRelation(relationCtx, ident, options,
+ Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
partition,
cols,
otherPlans.head,
@@ -491,9 +492,11 @@ class AstBuilder extends DataTypeAstBuilder
byName)
})
case ctx: InsertIntoReplaceWhereContext =>
+ val options = Option(ctx.optionsClause())
withIdentClause(ctx.identifierReference, Seq(query), (ident, otherPlans) => {
OverwriteByExpression.byPosition(
- createUnresolvedRelation(ctx.identifierReference, ident, Option(ctx.optionsClause())),
+ createUnresolvedRelation(ctx.identifierReference, ident, options,
+ Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
otherPlans.head,
expression(ctx.whereClause().booleanExpression()))
})
@@ -578,7 +581,8 @@ class AstBuilder extends DataTypeAstBuilder
override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
- val table = createUnresolvedRelation(ctx.identifierReference)
+ val table = createUnresolvedRelation(
+ ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
@@ -590,7 +594,8 @@ class AstBuilder extends DataTypeAstBuilder
}
override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
- val table = createUnresolvedRelation(ctx.identifierReference)
+ val table = createUnresolvedRelation(
+ ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
@@ -613,9 +618,6 @@ class AstBuilder extends DataTypeAstBuilder
override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val withSchemaEvolution = ctx.EVOLUTION() != null
- val targetTable = createUnresolvedRelation(ctx.target)
- val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
- val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
val sourceTableOrQuery = if (ctx.source != null) {
createUnresolvedRelation(ctx.source)
@@ -646,7 +648,7 @@ class AstBuilder extends DataTypeAstBuilder
s"Unrecognized matched action: ${clause.matchedAction().getText}")
}
}
- }
+ }.toSeq
val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
@@ -667,7 +669,7 @@ class AstBuilder extends DataTypeAstBuilder
s"Unrecognized matched action: ${clause.notMatchedAction().getText}")
}
}
- }
+ }.toSeq
val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map {
clause => {
val notMatchedBySourceAction = clause.notMatchedBySourceAction()
@@ -682,7 +684,7 @@ class AstBuilder extends DataTypeAstBuilder
s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}")
}
}
- }
+ }.toSeq
if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx)
}
@@ -701,13 +703,19 @@ class AstBuilder extends DataTypeAstBuilder
throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx)
}
+ val targetTable = createUnresolvedRelation(
+ ctx.target,
+ writePrivileges = MergeIntoTable.getWritePrivileges(
+ matchedActions, notMatchedActions, notMatchedBySourceActions))
+ val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
+ val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
- matchedActions.toSeq,
- notMatchedActions.toSeq,
- notMatchedBySourceActions.toSeq,
+ matchedActions,
+ notMatchedActions,
+ notMatchedBySourceActions,
withSchemaEvolution)
}
@@ -3130,10 +3138,13 @@ class AstBuilder extends DataTypeAstBuilder
*/
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext,
- optionsClause: Option[OptionsClauseContext] = None): LogicalPlan = withOrigin(ctx) {
+ optionsClause: Option[OptionsClauseContext] = None,
+ writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
- withIdentClause(ctx, parts =>
- new UnresolvedRelation(parts, options, isStreaming = false))
+ withIdentClause(ctx, parts => {
+ val relation = new UnresolvedRelation(parts, options, isStreaming = false)
+ relation.requireWritePrivileges(writePrivileges)
+ })
}
/**
@@ -3142,9 +3153,11 @@ class AstBuilder extends DataTypeAstBuilder
private def createUnresolvedRelation(
ctx: ParserRuleContext,
ident: Seq[String],
- optionsClause: Option[OptionsClauseContext]): UnresolvedRelation = withOrigin(ctx) {
+ optionsClause: Option[OptionsClauseContext],
+ writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
- new UnresolvedRelation(ident, options, isStreaming = false)
+ val relation = new UnresolvedRelation(ident, options, isStreaming = false)
+ relation.requireWritePrivileges(writePrivileges)
}
private def resolveOptions(
@@ -5020,7 +5033,8 @@ class AstBuilder extends DataTypeAstBuilder
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
- CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None),
+ CacheTable(
+ createUnresolvedRelation(ctx.identifierReference, ident, None, writePrivileges = Nil),
ident, isLazy, options)
}
})
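Summary of the privileges the parser now attaches to the target relation for each statement touched above:

```scala
// INSERT INTO ...                    -> INSERT
// INSERT OVERWRITE ...               -> INSERT, DELETE
// INSERT INTO ... REPLACE WHERE ...  -> INSERT, DELETE
// DELETE FROM ...                    -> DELETE
// UPDATE ...                         -> UPDATE
// MERGE INTO ...                     -> derived from its WHEN clauses via
//                                       MergeIntoTable.getWritePrivileges
// CACHE TABLE ...                    -> none (read-only)
```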
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 6339a18796fa..05628d7b1c98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -794,6 +794,21 @@ case class MergeIntoTable(
copy(targetTable = newLeft, sourceTable = newRight)
}
+object MergeIntoTable {
+ def getWritePrivileges(
+ matchedActions: Seq[MergeAction],
+ notMatchedActions: Seq[MergeAction],
+ notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = {
+ val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege]
+ (matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach {
+ case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE)
+ case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE)
+ case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT)
+ }
+ privileges.toSeq
+ }
+}
+
sealed abstract class MergeAction extends Expression with Unevaluable {
def condition: Option[Expression]
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 283c550c4556..6698f0a02140 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -403,9 +403,10 @@ private[sql] object CatalogV2Util {
def loadTable(
catalog: CatalogPlugin,
ident: Identifier,
- timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] =
+ timeTravelSpec: Option[TimeTravelSpec] = None,
+ writePrivilegesString: Option[String] = None): Option[Table] =
try {
- Option(getTable(catalog, ident, timeTravelSpec))
+ Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString))
} catch {
case _: NoSuchTableException => None
case _: NoSuchDatabaseException => None
@@ -414,8 +415,10 @@ private[sql] object CatalogV2Util {
def getTable(
catalog: CatalogPlugin,
ident: Identifier,
- timeTravelSpec: Option[TimeTravelSpec] = None): Table = {
+ timeTravelSpec: Option[TimeTravelSpec] = None,
+ writePrivilegesString: Option[String] = None): Table = {
if (timeTravelSpec.nonEmpty) {
+ assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel")
timeTravelSpec.get match {
case v: AsOfVersion =>
catalog.asTableCatalog.loadTable(ident, v.version)
@@ -423,7 +426,13 @@ private[sql] object CatalogV2Util {
catalog.asTableCatalog.loadTable(ident, ts.timestamp)
}
} else {
- catalog.asTableCatalog.loadTable(ident)
+ if (writePrivilegesString.isDefined) {
+ val writePrivileges = writePrivilegesString.get.split(",").map(_.trim)
+ .map(TableWritePrivilege.valueOf).toSet.asJava
+ catalog.asTableCatalog.loadTable(ident, writePrivileges)
+ } else {
+ catalog.asTableCatalog.loadTable(ident)
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index c930292f2793..756ec95c70d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -39,7 +39,16 @@ class DDLParserSuite extends AnalysisTest {
}
private def parseCompare(sql: String, expected: LogicalPlan): Unit = {
- comparePlans(parsePlan(sql), expected, checkAnalysis = false)
+ // We don't care about write privileges in this suite.
+ val parsed = parsePlan(sql).transform {
+ case u: UnresolvedRelation => u.clearWritePrivileges
+ case i: InsertIntoStatement =>
+ i.table match {
+ case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges)
+ case _ => i
+ }
+ }
+ comparePlans(parsed, expected, checkAnalysis = false)
}
private def internalException(sqlText: String): SparkThrowable = {
@@ -2635,20 +2644,20 @@ class DDLParserSuite extends AnalysisTest {
withSQLConf(
SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
optimizeInsertIntoValues.toString) {
- comparePlans(parsePlan(dateTypeSql), insertPartitionPlan(
+ parseCompare(dateTypeSql, insertPartitionPlan(
"2019-01-02", optimizeInsertIntoValues))
withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") {
- comparePlans(parsePlan(intervalTypeSql), insertPartitionPlan(
+ parseCompare(intervalTypeSql, insertPartitionPlan(
interval, optimizeInsertIntoValues))
}
- comparePlans(parsePlan(ymIntervalTypeSql), insertPartitionPlan(
+ parseCompare(ymIntervalTypeSql, insertPartitionPlan(
"INTERVAL '1-2' YEAR TO MONTH", optimizeInsertIntoValues))
- comparePlans(parsePlan(dtIntervalTypeSql),
+ parseCompare(dtIntervalTypeSql,
insertPartitionPlan(
"INTERVAL '1 02:03:04.128462' DAY TO SECOND", optimizeInsertIntoValues))
- comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan(
+ parseCompare(timestampTypeSql, insertPartitionPlan(
timestamp, optimizeInsertIntoValues))
- comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan(
+ parseCompare(binaryTypeSql, insertPartitionPlan(
binaryStr, optimizeInsertIntoValues))
}
}
@@ -2748,12 +2757,12 @@ class DDLParserSuite extends AnalysisTest {
// In each of the following cases, the DEFAULT reference parses as an unresolved attribute
// reference. We can handle these cases after the parsing stage, at later phases of analysis.
- comparePlans(parsePlan("VALUES (1, 2, DEFAULT) AS val"),
+ parseCompare("VALUES (1, 2, DEFAULT) AS val",
SubqueryAlias("val",
UnresolvedInlineTable(Seq("col1", "col2", "col3"), Seq(Seq(Literal(1), Literal(2),
UnresolvedAttribute("DEFAULT"))))))
- comparePlans(parsePlan(
- "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)"),
+ parseCompare(
+ "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)",
InsertIntoStatement(
UnresolvedRelation(Seq("t")),
Map("part" -> Some("2019-01-02")),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index e0217a5637a8..a6a32e87b742 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -40,7 +40,16 @@ class PlanParserSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.dsl.plans._
private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
- comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false)
+ // We don't care about write privileges in this suite.
+ val parsed = parsePlan(sqlCommand).transform {
+ case u: UnresolvedRelation => u.clearWritePrivileges
+ case i: InsertIntoStatement =>
+ i.table match {
+ case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges)
+ case _ => i
+ }
+ }
+ comparePlans(parsed, plan, checkAnalysis = false)
}
private def parseException(sqlText: String): SparkThrowable = {
@@ -1048,57 +1057,56 @@ class PlanParserSuite extends AnalysisTest {
errorClass = "PARSE_SYNTAX_ERROR",
parameters = Map("error" -> "'b'", "hint" -> ""))
- comparePlans(
- parsePlan("SELECT /*+ HINT */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ HINT */ * FROM t",
UnresolvedHint("HINT", Seq.empty, table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ BROADCASTJOIN(u) */ * FROM t",
UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ MAPJOIN(u) */ * FROM t",
UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t",
UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t",
UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"),
+ assertEqual(
+ "SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`",
UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")),
table("default.t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
+ assertEqual(
+ "SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a",
UnresolvedHint("MAPJOIN", Seq($"t"),
table("t").where(Literal(true)).groupBy($"a")($"a")).orderBy($"a".asc))
- comparePlans(
- parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ COALESCE(10) */ * FROM t",
UnresolvedHint("COALESCE", Seq(Literal(10)),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100)),
table("t").select(star())))
- comparePlans(
- parsePlan(
- "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"),
+ assertEqual(
+ "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t",
InsertIntoStatement(table("s"), Map.empty, Nil,
UnresolvedHint("REPARTITION", Seq(Literal(100)),
UnresolvedHint("COALESCE", Seq(Literal(500)),
UnresolvedHint("COALESCE", Seq(Literal(10)),
table("t").select(star())))), overwrite = false, ifPartitionNotExists = false))
- comparePlans(
- parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t",
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("REPARTITION", Seq(Literal(100)),
table("t").select(star()))))
@@ -1109,49 +1117,48 @@ class PlanParserSuite extends AnalysisTest {
errorClass = "PARSE_SYNTAX_ERROR",
parameters = Map("error" -> "'+'", "hint" -> ""))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(c) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100, c) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star()))))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star())))))
- comparePlans(
- parsePlan(
- """
- |SELECT
- |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
- |* FROM t
- """.stripMargin),
+ assertEqual(
+ """
+ |SELECT
+ |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
+ |* FROM t
+ """.stripMargin,
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")),
table("t").select(star()))))))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t",
UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t",
UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 991487170f17..60734efbf5bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -30,12 +30,15 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -473,7 +476,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
- val table = catalog.asTableCatalog.loadTable(ident) match {
+ val table = catalog.asTableCatalog.loadTable(ident, getWritePrivileges.toSet.asJava) match {
case _: V1Table =>
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
case t =>
@@ -504,7 +507,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def insertInto(tableIdent: TableIdentifier): Unit = {
runCommand(df.sparkSession) {
InsertIntoStatement(
- table = UnresolvedRelation(tableIdent),
+ table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges),
partitionSpec = Map.empty[String, Option[String]],
Nil,
query = df.logicalPlan,
@@ -513,6 +516,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}
+ private def getWritePrivileges: Seq[TableWritePrivilege] = mode match {
+ case SaveMode.Overwrite => Seq(INSERT, DELETE)
+ case _ => Seq(INSERT)
+ }
+
private def getBucketSpec: Option[BucketSpec] = {
if (sortColumnNames.isDefined && numBuckets.isEmpty) {
throw QueryCompilationErrors.sortByWithoutBucketingError()
@@ -588,7 +596,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val session = df.sparkSession
- val canUseV2 = lookupV2Provider().isDefined
+ val canUseV2 = lookupV2Provider().isDefined ||
+ df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined
session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) =>
@@ -609,7 +618,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def saveAsTable(
catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = {
- val tableOpt = try Option(catalog.loadTable(ident)) catch {
+ val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch {
case _: NoSuchTableException => None
}
@@ -670,7 +679,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val catalog = df.sparkSession.sessionState.catalog
val qualifiedIdent = catalog.qualifyIdentifier(tableIdent)
val tableExists = catalog.tableExists(qualifiedIdent)
- val tableName = qualifiedIdent.unquotedString
(tableExists, mode) match {
case (true, SaveMode.Ignore) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index d9ad0003a525..576d8276b56e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec}
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
@@ -169,7 +170,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
*/
@throws(classOf[NoSuchTableException])
def append(): Unit = {
- val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap)
+ val append = AppendData.byName(
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
+ logicalPlan, options.toMap)
runCommand(append)
}
@@ -186,7 +189,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def overwrite(condition: Column): Unit = {
val overwrite = OverwriteByExpression.byName(
- UnresolvedRelation(tableName), logicalPlan, expression(condition), options.toMap)
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
+ logicalPlan, expression(condition), options.toMap)
runCommand(overwrite)
}
@@ -206,7 +210,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def overwritePartitions(): Unit = {
val dynamicOverwrite = OverwritePartitionsDynamic.byName(
- UnresolvedRelation(tableName), logicalPlan, options.toMap)
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
+ logicalPlan, options.toMap)
runCommand(dynamicOverwrite)
}
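A usage sketch of the DataFrameWriterV2 paths changed above, assuming an existing SparkSession `spark` and a made-up v2 table name; user-facing behavior is unchanged, only the privileges passed to loadTable differ.

```scala
import org.apache.spark.sql.functions.col

val df = spark.table("source")

df.writeTo("cat.db.target").append()                       // loads the table with {INSERT}
df.writeTo("cat.db.target").overwrite(col("p") === "2024") // loads it with {INSERT, DELETE}
df.writeTo("cat.db.target").overwritePartitions()          // loads it with {INSERT, DELETE}
```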
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
index d8042720577d..6212a7fdb259 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
@@ -189,7 +189,8 @@ class MergeIntoWriter[T] private[sql] (
}
val merge = MergeIntoTable(
- UnresolvedRelation(tableName),
+ UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges(
+ matchedActions, notMatchedActions, notMatchedBySourceActions)),
logicalPlan,
on.expr,
matchedActions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 9053fb9cc73f..20e3b4e980f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -683,7 +683,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
}
private def supportsV1Command(catalog: CatalogPlugin): Boolean = {
- catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) &&
- !SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined
+ isSessionCatalog(catalog) &&
+ SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 5632595de7cf..89372017257d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode}
import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION}
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
@@ -84,7 +84,8 @@ case class CreateTableAsSelectExec(
}
val table = Option(catalog.createTable(
ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTable(ident))
+ partitioning.toArray, properties.asJava)
+ ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
writeToTable(catalog, table, writeOptions, ident, query)
}
}
@@ -164,7 +165,8 @@ case class ReplaceTableAsSelectExec(
}
val table = Option(catalog.createTable(
ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTable(ident))
+ partitioning.toArray, properties.asJava)
+ ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
writeToTable(catalog, table, writeOptions, ident, query)
}
}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out
index 3aea86b232cb..f9a282c2b927 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out
@@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val))], Formatted
-- !query
EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4
-- !query analysis
-ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode
+ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out
index 3aea86b232cb..f9a282c2b927 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out
@@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val))], Formatted
-- !query
EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4
-- !query analysis
-ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode
+ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
index 3830b47ba8a6..16077a78f389 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1139,7 +1139,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4
struct<plan:string>