@@ -110,6 +110,26 @@ public interface TableCatalog extends CatalogPlugin {
*/
Table loadTable(Identifier ident) throws NoSuchTableException;

/**
* Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data
* into this table later.
* <p>
* If the catalog supports views and contains a view for the identifier and not a table, this
* must throw {@link NoSuchTableException}.
*
* @param ident a table identifier
@param writePrivileges the write privileges required by the subsequent write operation
* @return the table's metadata
* @throws NoSuchTableException If the table doesn't exist or is a view
*
* @since 3.5.3
*/
default Table loadTable(
Identifier ident,
Set<TableWritePrivilege> writePrivileges) throws NoSuchTableException {
return loadTable(ident);
}
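
Note for catalog implementers: because the new overload defaults to delegating to `loadTable(ident)`, existing catalogs keep working unchanged; only catalogs that care about write privileges need to override it. A minimal sketch of such an override, using `DelegatingCatalogExtension` as a convenient base and a hypothetical `checkPrivileges` helper (not part of the Spark API):

```scala
import java.util.{Set => JSet}

import org.apache.spark.sql.connector.catalog.{
  DelegatingCatalogExtension, Identifier, Table, TableWritePrivilege}

class GuardedCatalog extends DelegatingCatalogExtension {
  override def loadTable(
      ident: Identifier,
      writePrivileges: JSet[TableWritePrivilege]): Table = {
    // Hypothetical: consult an external access-control system and throw
    // if the session user lacks any of the requested write privileges.
    checkPrivileges(ident, writePrivileges)
    super.loadTable(ident, writePrivileges)
  }

  private def checkPrivileges(
      ident: Identifier, privileges: JSet[TableWritePrivilege]): Unit = {
    // ... enforcement logic would go here ...
  }
}
```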

/**
* Load table metadata of a specific version by {@link Identifier identifier} from the catalog.
* <p>
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog;

/**
* The table write privileges that will be provided when loading a table.
*
* @since 3.5.3
*/
public enum TableWritePrivilege {
Contributor Author:
I only include write privileges, as the full set of privileges includes ALTER, REFERENCE, etc., which is not what we need for loadTable.

/**
* The privilege for adding rows to the table.
*/
INSERT,

/**
* The privilege for changing existing rows in the table.
*/
UPDATE,

/**
* The privilege for deleting rows from the table.
*/
DELETE
}
@@ -1331,8 +1331,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
cachedConnectRelation
}.getOrElse(cachedRelation)
}.orElse {
val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
val writePrivilegesString =
Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
val table = CatalogV2Util.loadTable(
catalog, ident, finalTimeTravelSpec, writePrivilegesString)
val loaded = createRelation(
catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
loaded.map { loadedRelation =>
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, Lo
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -134,10 +135,36 @@ case class UnresolvedRelation(

override def name: String = tableName

def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = {
if (privileges.nonEmpty) {
val newOptions = new java.util.HashMap[String, String]
newOptions.putAll(options)
newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(","))
copy(options = new CaseInsensitiveStringMap(newOptions))
} else {
this
}
}

def clearWritePrivileges: UnresolvedRelation = {
if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) {
val newOptions = new java.util.HashMap[String, String]
newOptions.putAll(options)
newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
copy(options = new CaseInsensitiveStringMap(newOptions))
} else {
this
}
}
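
The two helpers above form a round trip: the parser attaches the required privileges as an internal, comma-joined option; the analyzer forwards that string to the catalog; and `clearWritePrivileges` then strips the marker so it does not leak into the resolved relation's options (which is why `createRelation` in the Analyzer hunk above is given `u.clearWritePrivileges.options`). A small illustrative sketch:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val base = UnresolvedRelation(
  Seq("db", "tbl"), CaseInsensitiveStringMap.empty(), isStreaming = false)
// Parser side: record what the statement needs, e.g. INSERT OVERWRITE.
val tagged = base.requireWritePrivileges(
  Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE))
tagged.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) // "INSERT,DELETE"
// Analyzer side: after the catalog has seen the string, drop the marker.
val cleared = tagged.clearWritePrivileges
cleared.options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) // false
```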

final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}

object UnresolvedRelation {
// An internal option of `UnresolvedRelation` to specify the required write privileges when
// writing data to this relation.
val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__"

Contributor Author: We could add a new field to UnresolvedRelation, but that may break third-party catalyst rules.

def apply(
tableIdentifier: TableIdentifier,
extraOptions: CaseInsensitiveStringMap,
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors}
@@ -469,7 +469,7 @@ class AstBuilder extends DataTypeAstBuilder
= visitInsertIntoTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident, options),
createUnresolvedRelation(relationCtx, ident, options, Seq(TableWritePrivilege.INSERT)),
partition,
cols,
otherPlans.head,
@@ -482,7 +482,8 @@
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident, options),
createUnresolvedRelation(relationCtx, ident, options,
Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
partition,
cols,
otherPlans.head,
@@ -491,9 +492,11 @@
byName)
})
case ctx: InsertIntoReplaceWhereContext =>
val options = Option(ctx.optionsClause())
withIdentClause(ctx.identifierReference, Seq(query), (ident, otherPlans) => {
OverwriteByExpression.byPosition(
createUnresolvedRelation(ctx.identifierReference, ident, Option(ctx.optionsClause())),
createUnresolvedRelation(ctx.identifierReference, ident, options,
Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
otherPlans.head,
expression(ctx.whereClause().booleanExpression()))
})
@@ -578,7 +581,8 @@ class AstBuilder extends DataTypeAstBuilder

override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference)
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
@@ -590,7 +594,8 @@
}

override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference)
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
@@ -613,9 +618,6 @@

override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val withSchemaEvolution = ctx.EVOLUTION() != null
val targetTable = createUnresolvedRelation(ctx.target)
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)

val sourceTableOrQuery = if (ctx.source != null) {
createUnresolvedRelation(ctx.source)
@@ -646,7 +648,7 @@
s"Unrecognized matched action: ${clause.matchedAction().getText}")
}
}
}
}.toSeq
val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
@@ -667,7 +669,7 @@
s"Unrecognized matched action: ${clause.notMatchedAction().getText}")
}
}
}
}.toSeq
val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map {
clause => {
val notMatchedBySourceAction = clause.notMatchedBySourceAction()
@@ -682,7 +684,7 @@
s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}")
}
}
}
}.toSeq
if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx)
}
@@ -701,13 +703,19 @@
throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx)
}

val targetTable = createUnresolvedRelation(
ctx.target,
writePrivileges = MergeIntoTable.getWritePrivileges(
matchedActions, notMatchedActions, notMatchedBySourceActions))
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
matchedActions.toSeq,
notMatchedActions.toSeq,
notMatchedBySourceActions.toSeq,
matchedActions,
notMatchedActions,
notMatchedBySourceActions,
withSchemaEvolution)
}

@@ -3130,10 +3138,13 @@
*/
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext,
optionsClause: Option[OptionsClauseContext] = None): LogicalPlan = withOrigin(ctx) {
optionsClause: Option[OptionsClauseContext] = None,
writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
withIdentClause(ctx, parts =>
new UnresolvedRelation(parts, options, isStreaming = false))
withIdentClause(ctx, parts => {
val relation = new UnresolvedRelation(parts, options, isStreaming = false)
relation.requireWritePrivileges(writePrivileges)
})
}
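
Taken together, the parser changes above request these privileges per statement (a summary of this PR's hunks, not a new API):

- INSERT INTO: INSERT
- INSERT OVERWRITE and INSERT ... REPLACE WHERE: INSERT, DELETE
- DELETE FROM: DELETE
- UPDATE: UPDATE
- MERGE INTO: the union computed by MergeIntoTable.getWritePrivileges from the WHEN clauses
- CACHE TABLE: none (Nil), since it only reads the table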

/**
@@ -3142,9 +3153,11 @@
private def createUnresolvedRelation(
ctx: ParserRuleContext,
ident: Seq[String],
optionsClause: Option[OptionsClauseContext]): UnresolvedRelation = withOrigin(ctx) {
optionsClause: Option[OptionsClauseContext],
writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
new UnresolvedRelation(ident, options, isStreaming = false)
val relation = new UnresolvedRelation(ident, options, isStreaming = false)
relation.requireWritePrivileges(writePrivileges)
}

private def resolveOptions(
@@ -5020,7 +5033,8 @@
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None),
CacheTable(
createUnresolvedRelation(ctx.identifierReference, ident, None, writePrivileges = Nil),
ident, isLazy, options)
}
})
@@ -794,6 +794,21 @@ case class MergeIntoTable(
copy(targetTable = newLeft, sourceTable = newRight)
}

object MergeIntoTable {
def getWritePrivileges(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = {
val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege]
(matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach {
case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE)
case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE)
case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT)
}
privileges.toSeq
}
}
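
A quick illustration of the derivation, for a MERGE with one update clause and one insert clause (clause values are illustrative):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{
  InsertStarAction, MergeIntoTable, UpdateStarAction}

// MERGE ... WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *
val privileges = MergeIntoTable.getWritePrivileges(
  matchedActions = Seq(UpdateStarAction(condition = None)),
  notMatchedActions = Seq(InsertStarAction(condition = None)),
  notMatchedBySourceActions = Nil)
// => contains UPDATE and INSERT; DELETE would be added only if some
// clause carried a DeleteAction.
```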

sealed abstract class MergeAction extends Expression with Unevaluable {
def condition: Option[Expression]
override def nullable: Boolean = false
@@ -403,9 +403,10 @@ private[sql] object CatalogV2Util {
def loadTable(
catalog: CatalogPlugin,
ident: Identifier,
timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] =
timeTravelSpec: Option[TimeTravelSpec] = None,
writePrivilegesString: Option[String] = None): Option[Table] =
try {
Option(getTable(catalog, ident, timeTravelSpec))
Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString))
} catch {
case _: NoSuchTableException => None
case _: NoSuchDatabaseException => None
@@ -414,16 +415,24 @@
def getTable(
catalog: CatalogPlugin,
ident: Identifier,
timeTravelSpec: Option[TimeTravelSpec] = None): Table = {
timeTravelSpec: Option[TimeTravelSpec] = None,
writePrivilegesString: Option[String] = None): Table = {
if (timeTravelSpec.nonEmpty) {
assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel")
timeTravelSpec.get match {
case v: AsOfVersion =>
catalog.asTableCatalog.loadTable(ident, v.version)
case ts: AsOfTimestamp =>
catalog.asTableCatalog.loadTable(ident, ts.timestamp)
}
} else {
catalog.asTableCatalog.loadTable(ident)
if (writePrivilegesString.isDefined) {
val writePrivileges = writePrivilegesString.get.split(",").map(_.trim)
.map(TableWritePrivilege.valueOf).toSet.asJava
catalog.asTableCatalog.loadTable(ident, writePrivileges)
} else {
catalog.asTableCatalog.loadTable(ident)
}
}
}
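
Since the privileges travel as a comma-joined string, the decode here is just split, trim, and valueOf. A sketch of that decoding in isolation (values illustrative):

```scala
import org.apache.spark.sql.connector.catalog.TableWritePrivilege

val encoded = "INSERT,DELETE" // as produced by requireWritePrivileges
val decoded: Set[TableWritePrivilege] =
  encoded.split(",").map(_.trim).map(TableWritePrivilege.valueOf).toSet
// TableWritePrivilege.valueOf throws IllegalArgumentException for names
// outside INSERT/UPDATE/DELETE, so a malformed option fails fast here.
```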
