
Commit 0f4d9aa

[SPARK-26946][SQL][FOLLOWUP] Require lookup function
(cherry picked from commit 4e99de0)
1 parent bb4514c · commit 0f4d9aa

5 files changed: 105 additions & 123 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala

Lines changed: 12 additions & 16 deletions
@@ -26,35 +26,31 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 @Experimental
 trait LookupCatalog {
 
-  def lookupCatalog: Option[(String) => CatalogPlugin] = None
+  protected def lookupCatalog(name: String): CatalogPlugin
 
   type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)
 
   /**
    * Extract catalog plugin and identifier from a multi-part identifier.
    */
   object CatalogObjectIdentifier {
-    def unapply(parts: Seq[String]): Option[CatalogObjectIdentifier] = lookupCatalog.map { lookup =>
-      parts match {
-        case Seq(name) =>
-          (None, Identifier.of(Array.empty, name))
-        case Seq(catalogName, tail @ _*) =>
-          try {
-            val catalog = lookup(catalogName)
-            (Some(catalog), Identifier.of(tail.init.toArray, tail.last))
-          } catch {
-            case _: CatalogNotFoundException =>
-              (None, Identifier.of(parts.init.toArray, parts.last))
-          }
-      }
+    def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
+      case Seq(name) =>
+        Some((None, Identifier.of(Array.empty, name)))
+      case Seq(catalogName, tail @ _*) =>
+        try {
+          Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last)))
+        } catch {
+          case _: CatalogNotFoundException =>
+            Some((None, Identifier.of(parts.init.toArray, parts.last)))
+        }
     }
   }
 
   /**
    * Extract legacy table identifier from a multi-part identifier.
    *
-   * For legacy support only. Please use
-   * [[org.apache.spark.sql.catalog.v2.LookupCatalog.CatalogObjectIdentifier]] in DSv2 code paths.
+   * For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths.
    */
   object AsTableIdentifier {
     def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
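
The change above turns lookupCatalog from an optional function value into a required method, so CatalogObjectIdentifier.unapply can return Some unconditionally and callers no longer need to handle a missing lookup. A minimal sketch of the new contract (SketchResolver, its catalog map, and describe are illustrative, not part of this commit):

import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}

object SketchResolver extends LookupCatalog {
  // Assumed registry; a real implementation would return actual plugins.
  private val catalogs = Map.empty[String, CatalogPlugin]

  // Implementors must now supply the lookup; there is no Option to leave as None.
  override protected def lookupCatalog(name: String): CatalogPlugin =
    catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))

  def describe(parts: Seq[String]): String = parts match {
    // unapply now returns Some, so these two cases cover every input.
    case CatalogObjectIdentifier(Some(catalog), ident) =>
      s"catalog ${catalog.name}: ${(ident.namespace :+ ident.name).mkString(".")}"
    case CatalogObjectIdentifier(None, ident) =>
      s"session catalog: ${(ident.namespace :+ ident.name).mkString(".")}"
  }
}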

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 7 deletions
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog}
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -96,18 +96,15 @@ object AnalysisContext {
 class Analyzer(
     catalog: SessionCatalog,
     conf: SQLConf,
-    maxIterations: Int,
-    override val lookupCatalog: Option[(String) => CatalogPlugin] = None)
+    maxIterations: Int)
   extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {
 
   def this(catalog: SessionCatalog, conf: SQLConf) = {
     this(catalog, conf, conf.optimizerMaxIterations)
   }
 
-  def this(lookupCatalog: Option[(String) => CatalogPlugin], catalog: SessionCatalog,
-      conf: SQLConf) = {
-    this(catalog, conf, conf.optimizerMaxIterations, lookupCatalog)
-  }
+  override protected def lookupCatalog(name: String): CatalogPlugin =
+    throw new CatalogNotFoundException("No catalog lookup function")
 
   def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
     AnalysisHelper.markInAnalyzer {
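
With the Option parameter gone, an Analyzer built through the remaining constructors fails every catalog lookup with CatalogNotFoundException until a subclass overrides the method. A hedged sketch of how a caller might now wire in a lookup function (CatalogAwareAnalyzer is a hypothetical name standing in for whatever session-side subclass supplies the override):

import org.apache.spark.sql.catalog.v2.CatalogPlugin
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.internal.SQLConf

class CatalogAwareAnalyzer(
    findCatalog: String => CatalogPlugin,
    catalog: SessionCatalog,
    conf: SQLConf)
  extends Analyzer(catalog, conf) {

  // Replaces the removed auxiliary constructor that took an Option-valued lookup.
  override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name)
}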
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.catalog.v2
+
+import org.scalatest.Inside
+import org.scalatest.Matchers._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+private case class TestCatalogPlugin(override val name: String) extends CatalogPlugin {
+
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit
+}
+
+class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
+  import CatalystSqlParser._
+
+  private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap
+
+  override def lookupCatalog(name: String): CatalogPlugin =
+    catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+
+  test("catalog object identifier") {
+    Seq(
+      ("tbl", None, Seq.empty, "tbl"),
+      ("db.tbl", None, Seq("db"), "tbl"),
+      ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
+      ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
+      ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
+      ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
+      ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+      ("`db.tbl`", None, Seq.empty, "db.tbl"),
+      ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
+        Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
+      case (sql, expectedCatalog, namespace, name) =>
+        inside(parseMultipartIdentifier(sql)) {
+          case CatalogObjectIdentifier(catalog, ident) =>
+            catalog shouldEqual expectedCatalog
+            ident shouldEqual Identifier.of(namespace.toArray, name)
+        }
+    }
+  }
+
+  test("table identifier") {
+    Seq(
+      ("tbl", "tbl", None),
+      ("db.tbl", "tbl", Some("db")),
+      ("`db.tbl`", "db.tbl", None),
+      ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json",
+        Some("org.apache.spark.sql.json"))).foreach {
+      case (sql, table, db) =>
+        inside (parseMultipartIdentifier(sql)) {
+          case AsTableIdentifier(ident) =>
+            ident shouldEqual TableIdentifier(table, db)
+        }
+    }
+    Seq(
+      "prod.func",
+      "prod.db.tbl",
+      "ns1.ns2.tbl").foreach { sql =>
+      parseMultipartIdentifier(sql) match {
+        case AsTableIdentifier(_) =>
+          fail(s"$sql should not be resolved as TableIdentifier")
+        case _ =>
+      }
+    }
+  }
+}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala

Lines changed: 0 additions & 99 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ case class DataSourceResolution(
 
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
 
-  override def lookupCatalog: Option[String => CatalogPlugin] = Some(findCatalog)
+  override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name)
 
   def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog)
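
The same one-line adaptation works for any component that already carries a String => CatalogPlugin function. A self-contained sketch, including the fallback behavior when the first name part is not a known catalog (SketchResolution and the demo names are illustrative only, not part of this commit):

import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}

object SketchResolutionDemo {
  case class SketchResolution(findCatalog: String => CatalogPlugin) extends LookupCatalog {
    // Same shape as DataSourceResolution's override above.
    override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name)
  }

  def main(args: Array[String]): Unit = {
    // No catalogs registered: every lookup throws CatalogNotFoundException.
    val resolution = SketchResolution(
      name => throw new CatalogNotFoundException(s"$name not found"))

    Seq("unknown", "db", "tbl") match {
      // unapply catches the exception and falls back to treating the whole
      // name as namespace ["unknown", "db"] plus table "tbl", with no catalog.
      case resolution.CatalogObjectIdentifier(None, ident) =>
        assert(ident.namespace.sameElements(Array("unknown", "db")))
        assert(ident.name == "tbl")
    }
  }
}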