diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f94b9c2115044..586af62269a0a 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -277,8 +277,10 @@ def getOrCreate(self) -> "SparkSession":
                 # Do not update `SparkConf` for existing `SparkContext`, as it's shared
                 # by all sessions.
                 session = SparkSession(sc, options=self._options)
-            for key, value in self._options.items():
-                session._jsparkSession.sessionState().conf().setConfString(key, value)
+            else:
+                getattr(
+                    getattr(session._jvm, "SparkSession$"), "MODULE$"
+                ).applyModifiableSettings(session._jsparkSession, self._options)
             return session
 
     builder = Builder()
@@ -291,7 +293,7 @@ def __init__(
         self,
         sparkContext: SparkContext,
         jsparkSession: Optional[JavaObject] = None,
-        options: Optional[Dict[str, Any]] = {},
+        options: Dict[str, Any] = {},
     ):
         from pyspark.sql.context import SQLContext
 
@@ -304,8 +306,15 @@ def __init__(
                 and not self._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped()
             ):
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
+                getattr(getattr(self._jvm, "SparkSession$"), "MODULE$").applyModifiableSettings(
+                    jsparkSession, options
+                )
             else:
                 jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options)
+        else:
+            getattr(getattr(self._jvm, "SparkSession$"), "MODULE$").applyModifiableSettings(
+                jsparkSession, options
+            )
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 06771fac896ba..84fa23d9edfeb 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -273,12 +273,14 @@ def test_another_spark_session(self):
         session2 = None
         try:
             session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
-            session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
+            session2 = SparkSession.builder.config(
+                "spark.sql.codegen.comments", "true"
+            ).getOrCreate()
 
             self.assertEqual(session1.conf.get("key1"), "value1")
             self.assertEqual(session2.conf.get("key1"), "value1")
-            self.assertEqual(session1.conf.get("key2"), "value2")
-            self.assertEqual(session2.conf.get("key2"), "value2")
+            self.assertEqual(session1.conf.get("spark.sql.codegen.comments"), "false")
+            self.assertEqual(session2.conf.get("spark.sql.codegen.comments"), "false")
             self.assertEqual(session1.sparkContext, session2.sparkContext)
 
             self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index df110aa269e7b..cd101e502f8a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -930,7 +930,7 @@ object SparkSession extends Logging {
       // Get the session from current thread's active session.
       var session = activeThreadSession.get()
       if ((session ne null) && !session.sparkContext.isStopped) {
-        applyModifiableSettings(session)
+        applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava))
         return session
       }
 
@@ -939,7 +939,7 @@ object SparkSession extends Logging {
         // If the current thread does not have an active session, get it from the global session.
         session = defaultSession.get()
         if ((session ne null) && !session.sparkContext.isStopped) {
-          applyModifiableSettings(session)
+          applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava))
           return session
         }
 
@@ -967,22 +967,6 @@ object SparkSession extends Logging {
 
       return session
     }
-
-    private def applyModifiableSettings(session: SparkSession): Unit = {
-      val (staticConfs, otherConfs) =
-        options.partition(kv => SQLConf.isStaticConfigKey(kv._1))
-
-      otherConfs.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
-
-      if (staticConfs.nonEmpty) {
-        logWarning("Using an existing SparkSession; the static sql configurations will not take" +
-          " effect.")
-      }
-      if (otherConfs.nonEmpty) {
-        logWarning("Using an existing SparkSession; some spark core configurations may not take" +
-          " effect.")
-      }
-    }
   }
 
   /**
@@ -1074,6 +1058,28 @@ object SparkSession extends Logging {
       throw new IllegalStateException("No active or default Spark session found")))
   }
 
+  /**
+   * Apply modifiable settings to an existing [[SparkSession]]. This method is used
+   * by both Scala and Python, so it is placed under the [[SparkSession]] object.
+   */
+  private[sql] def applyModifiableSettings(
+      session: SparkSession,
+      options: java.util.HashMap[String, String]): Unit = {
+    val (staticConfs, otherConfs) =
+      options.asScala.partition(kv => SQLConf.isStaticConfigKey(kv._1))
+
+    otherConfs.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
+
+    if (staticConfs.nonEmpty) {
+      logWarning("Using an existing SparkSession; the static sql configurations will not take" +
+        " effect.")
+    }
+    if (otherConfs.nonEmpty) {
+      logWarning("Using an existing SparkSession; some spark core configurations may not take" +
+        " effect.")
+    }
+  }
+
   /**
    * Returns a cloned SparkSession with all specified configurations disabled, or
    * the original SparkSession if all configurations are already disabled.
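
For context, a minimal PySpark sketch of the behavior this patch establishes. The config keys are chosen for illustration only: spark.sql.shuffle.partitions is a modifiable runtime conf, while spark.sql.codegen.comments is a static conf whose default is "false", as exercised by the updated test above.

    from pyspark.sql import SparkSession

    # The first getOrCreate() creates the session and its SparkContext.
    spark1 = SparkSession.builder.master("local[1]").getOrCreate()

    # Reusing the session: a modifiable runtime conf now propagates to the
    # existing JVM session via SparkSession$.MODULE$.applyModifiableSettings.
    spark2 = SparkSession.builder.config(
        "spark.sql.shuffle.partitions", "4"
    ).getOrCreate()
    assert spark1.conf.get("spark.sql.shuffle.partitions") == "4"

    # A static SQL conf is ignored for an existing session; the JVM side logs
    # "Using an existing SparkSession; the static sql configurations will not
    # take effect." and the conf keeps its default value.
    spark3 = SparkSession.builder.config(
        "spark.sql.codegen.comments", "true"
    ).getOrCreate()
    assert spark3.conf.get("spark.sql.codegen.comments") == "false"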