Skip to content

Commit aef2268

Browse files
RussellSpitzerJackey Lee
authored andcommitted
[SPARK-25003][PYSPARK] Use SessionExtensions in Pyspark
Master ## What changes were proposed in this pull request? Previously Pyspark used the private constructor for SparkSession when building that object. This resulted in a SparkSession without checking the sql.extensions parameter for additional session extensions. To fix this we instead use the Session.builder() path as SparkR uses, this loads the extensions and allows their use in PySpark. ## How was this patch tested? An integration test was added which mimics the Scala test for the same feature. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes apache#21990 from RussellSpitzer/SPARK-25003-master. Authored-by: Russell Spitzer <[email protected]> Signed-off-by: hyukjinkwon <[email protected]>
1 parent 770be8a commit aef2268

2 files changed

Lines changed: 80 additions & 18 deletions

File tree

python/pyspark/sql/tests.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,6 +3837,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
38373837
"The callback from the query execution listener should be called after 'toPandas'")
38383838

38393839

3840+
class SparkExtensionsTest(unittest.TestCase):
3841+
# These tests are separate because it uses 'spark.sql.extensions' which is
3842+
# static and immutable. This can't be set or unset, for example, via `spark.conf`.
3843+
3844+
@classmethod
3845+
def setUpClass(cls):
3846+
import glob
3847+
from pyspark.find_spark_home import _find_spark_home
3848+
3849+
SPARK_HOME = _find_spark_home()
3850+
filename_pattern = (
3851+
"sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
3852+
"SparkSessionExtensionSuite.class")
3853+
if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
3854+
raise unittest.SkipTest(
3855+
"'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
3856+
"available. Will skip the related tests.")
3857+
3858+
# Note that 'spark.sql.extensions' is a static immutable configuration.
3859+
cls.spark = SparkSession.builder \
3860+
.master("local[4]") \
3861+
.appName(cls.__name__) \
3862+
.config(
3863+
"spark.sql.extensions",
3864+
"org.apache.spark.sql.MyExtensions") \
3865+
.getOrCreate()
3866+
3867+
@classmethod
3868+
def tearDownClass(cls):
3869+
cls.spark.stop()
3870+
3871+
def test_use_custom_class_for_extensions(self):
3872+
self.assertTrue(
3873+
self.spark._jsparkSession.sessionState().planner().strategies().contains(
3874+
self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
3875+
"MySparkStrategy not found in active planner strategies")
3876+
self.assertTrue(
3877+
self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
3878+
self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
3879+
"MyRule not found in extended resolution rules")
3880+
3881+
38403882
class SparkSessionTests(PySparkTestCase):
38413883

38423884
# This test is separate because it's closely related with session's start and stop.

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,17 @@ class SparkSession private(
8484
// The call site where this SparkSession was constructed.
8585
private val creationSite: CallSite = Utils.getCallSite()
8686

87+
/**
88+
* Constructor used in Pyspark. Contains explicit application of Spark Session Extensions
89+
* which otherwise only occurs during getOrCreate. We cannot add this to the default constructor
90+
* since that would cause every new session to reinvoke Spark Session Extensions on the currently
91+
* running extensions.
92+
*/
8793
private[sql] def this(sc: SparkContext) {
88-
this(sc, None, None, new SparkSessionExtensions)
94+
this(sc, None, None,
95+
SparkSession.applyExtensions(
96+
sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
97+
new SparkSessionExtensions))
8998
}
9099

91100
sparkContext.assertNotStopped()
@@ -936,23 +945,9 @@ object SparkSession extends Logging {
936945
// Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
937946
}
938947

939-
// Initialize extensions if the user has defined a configurator class.
940-
val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
941-
if (extensionConfOption.isDefined) {
942-
val extensionConfClassName = extensionConfOption.get
943-
try {
944-
val extensionConfClass = Utils.classForName(extensionConfClassName)
945-
val extensionConf = extensionConfClass.newInstance()
946-
.asInstanceOf[SparkSessionExtensions => Unit]
947-
extensionConf(extensions)
948-
} catch {
949-
// Ignore the error if we cannot find the class or when the class has the wrong type.
950-
case e @ (_: ClassCastException |
951-
_: ClassNotFoundException |
952-
_: NoClassDefFoundError) =>
953-
logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
954-
}
955-
}
948+
applyExtensions(
949+
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
950+
extensions)
956951

957952
session = new SparkSession(sparkContext, None, None, extensions)
958953
options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1137,4 +1132,29 @@ object SparkSession extends Logging {
11371132
SparkSession.clearDefaultSession()
11381133
}
11391134
}
1135+
1136+
/**
1137+
* Initialize extensions for given extension classname. This class will be applied to the
1138+
* extensions passed into this function.
1139+
*/
1140+
private def applyExtensions(
1141+
extensionOption: Option[String],
1142+
extensions: SparkSessionExtensions): SparkSessionExtensions = {
1143+
if (extensionOption.isDefined) {
1144+
val extensionConfClassName = extensionOption.get
1145+
try {
1146+
val extensionConfClass = Utils.classForName(extensionConfClassName)
1147+
val extensionConf = extensionConfClass.newInstance()
1148+
.asInstanceOf[SparkSessionExtensions => Unit]
1149+
extensionConf(extensions)
1150+
} catch {
1151+
// Ignore the error if we cannot find the class or when the class has the wrong type.
1152+
case e@(_: ClassCastException |
1153+
_: ClassNotFoundException |
1154+
_: NoClassDefFoundError) =>
1155+
logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
1156+
}
1157+
}
1158+
extensions
1159+
}
11401160
}

0 commit comments

Comments
 (0)