core/src/main/scala/org/apache/spark/SparkContext.scala (4 additions, 2 deletions)

@@ -82,8 +82,10 @@ class SparkContext(config: SparkConf) extends Logging {
   // The call site where this SparkContext was constructed.
   private val creationSite: CallSite = Utils.getCallSite()
 
-  // In order to prevent SparkContext from being created in executors.
-  SparkContext.assertOnDriver()
+  if (!config.get(ALLOW_SPARK_CONTEXT_IN_EXECUTORS)) {
+    // In order to prevent SparkContext from being created in executors.
+    SparkContext.assertOnDriver()
+  }
 
   // In order to prevent multiple SparkContexts from being active at the same time, mark this
   // context as having started construction.
core/src/main/scala/org/apache/spark/internal/config/package.scala

@@ -1814,4 +1814,10 @@ package object config {
       .bytesConf(ByteUnit.BYTE)
       .createOptional
 
+  private[spark] val ALLOW_SPARK_CONTEXT_IN_EXECUTORS =
+    ConfigBuilder("spark.driver.allowSparkContextInExecutors")
+      .doc("If set to true, SparkContext can be created in executors.")
+      .version("3.0.1")
+      .booleanConf
+      .createWithDefault(true)
 }
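Together, these two hunks gate the existing driver-only assertion behind the new flag: the SparkContext constructor reads ALLOW_SPARK_CONTEXT_IN_EXECUTORS from its own conf and only calls assertOnDriver() when the flag is false. A minimal sketch of the resulting behaviour, assuming the default of true shown above; the app names and master URLs below are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}

object NestedContextSketch {
  def main(args: Array[String]): Unit = {
    // Driver-side context.
    val sc = new SparkContext(
      new SparkConf().setAppName("nested-context-sketch").setMaster("local-cluster[2, 1, 1024]"))

    sc.range(0, 1).foreach { _ =>
      // Runs on an executor. With the default (true) this nested context is allowed;
      // setting "spark.driver.allowSparkContextInExecutors" to "false" on this inner
      // conf would re-enable the assertion and fail the task with a SparkException.
      val inner = new SparkContext(new SparkConf().setAppName("inner").setMaster("local"))
      inner.stop()
    }

    sc.stop()
  }
}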
core/src/test/scala/org/apache/spark/SparkContextSuite.scala (11 additions, 2 deletions)

@@ -951,17 +951,26 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
     }
   }
 
-  test("SPARK-32160: Disallow to create SparkContext in executors") {
+  test("SPARK-32160: Disallow to create SparkContext in executors if the config is set") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]"))
 
     val error = intercept[SparkException] {
       sc.range(0, 1).foreach { _ =>
-        new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+        new SparkContext(new SparkConf().setAppName("test").setMaster("local")
+          .set(ALLOW_SPARK_CONTEXT_IN_EXECUTORS, false))
       }
     }.getMessage()
 
     assert(error.contains("SparkContext should only be created and accessed on the driver."))
   }
+
+  test("SPARK-32160: Allow to create SparkContext in executors") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]"))
+
+    sc.range(0, 1).foreach { _ =>
+      new SparkContext(new SparkConf().setAppName("test").setMaster("local")).stop()
+    }
+  }
 }
 
 object SparkContextSuite {
python/pyspark/context.py (4 additions, 2 deletions)

@@ -119,8 +119,10 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ...
         ValueError:...
         """
-        # In order to prevent SparkContext from being created in executors.
-        SparkContext._assert_on_driver()
+        if (conf is not None and
+                conf.get("spark.driver.allowSparkContextInExecutors", "true").lower() != "true"):
+            # In order to prevent SparkContext from being created in executors.
+            SparkContext._assert_on_driver()
 
         self._callsite = first_spark_call() or CallSite(None, None, None)
         if gateway is not None and gateway.gateway_parameters.auth_token is None:
python/pyspark/tests/test_context.py (18 additions, 2 deletions)

@@ -268,13 +268,29 @@ def test_resources(self):
         self.assertEqual(len(resources), 0)
 
     def test_disallow_to_create_spark_context_in_executors(self):
-        # SPARK-32160: SparkContext should not be created in executors.
+        # SPARK-32160: SparkContext should not be created in executors if the config is set.
+
+        def create_spark_context():
+            conf = SparkConf().set("spark.driver.allowSparkContextInExecutors", "false")
+            with SparkContext(conf=conf):
+                pass
+
         with SparkContext("local-cluster[3, 1, 1024]") as sc:
             with self.assertRaises(Exception) as context:
-                sc.range(2).foreach(lambda _: SparkContext())
+                sc.range(2).foreach(lambda _: create_spark_context())
             self.assertIn("SparkContext should only be created and accessed on the driver.",
                           str(context.exception))
 
+    def test_allow_to_create_spark_context_in_executors(self):
+        # SPARK-32160: SparkContext can be created in executors.
+
+        def create_spark_context():
+            with SparkContext():
+                pass
+
+        with SparkContext("local-cluster[3, 1, 1024]") as sc:
+            sc.range(2).foreach(lambda _: create_spark_context())
+
 
 class ContextTestsWithResources(unittest.TestCase):
 
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala (8 additions, 4 deletions)

@@ -29,6 +29,7 @@ import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.ALLOW_SPARK_CONTEXT_IN_EXECUTORS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalog.Catalog
@@ -900,7 +901,13 @@
      * @since 2.0.0
      */
     def getOrCreate(): SparkSession = synchronized {
-      assertOnDriver()
+      val sparkConf = new SparkConf()
+      options.foreach { case (k, v) => sparkConf.set(k, v) }
+
+      if (!sparkConf.get(ALLOW_SPARK_CONTEXT_IN_EXECUTORS)) {
+        assertOnDriver()
+      }
+
       // Get the session from current thread's active session.
       var session = activeThreadSession.get()
       if ((session ne null) && !session.sparkContext.isStopped) {
@@ -919,9 +926,6 @@
 
       // No active nor global default session. Create a new one.
       val sparkContext = userSuppliedContext.getOrElse {
-        val sparkConf = new SparkConf()
-        options.foreach { case (k, v) => sparkConf.set(k, v) }
-
         // set a random app name if not given.
         if (!sparkConf.contains("spark.app.name")) {
           sparkConf.setAppName(java.util.UUID.randomUUID().toString)
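In getOrCreate() the builder options are now folded into a SparkConf up front, the new flag is checked against that conf, and the later userSuppliedContext block reuses the same conf (hence the deletion in the last hunk). A rough usage sketch, again assuming the default of true and purely illustrative app names and master URLs:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local-cluster[2, 1, 1024]")
  .appName("nested-session-sketch")
  .getOrCreate()

// Runs on executors. With the default (true) the nested session is created;
// adding .config("spark.driver.allowSparkContextInExecutors", "false") to the
// inner builder would make assertOnDriver() fire and fail the task instead.
spark.range(1).foreach { _ =>
  SparkSession.builder().master("local").getOrCreate().stop()
}

spark.stop()

Building the conf once keeps the flag check and the eventual SparkContext construction reading from the same set of options.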
sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala

@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.internal.config.ALLOW_SPARK_CONTEXT_IN_EXECUTORS
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -240,4 +241,27 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532-2")
     assert(session.conf.get(WAREHOUSE_PATH) === "SPARK-31532-db-2")
   }
+
+  test("SPARK-32160: Disallow to create SparkSession in executors if the config is set") {
+    val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate()
+
+    val error = intercept[SparkException] {
+      session.range(1).foreach { v =>
+        SparkSession.builder.master("local")
+          .config(ALLOW_SPARK_CONTEXT_IN_EXECUTORS.key, false).getOrCreate()
+        ()
+      }
+    }.getMessage()
+
+    assert(error.contains("SparkSession should only be created and accessed on the driver."))
+  }
+
+  test("SPARK-32160: Allow to create SparkSession in executors") {
+    val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate()
+
+    session.range(1).foreach { v =>
+      SparkSession.builder.master("local").getOrCreate().stop()
+      ()
+    }
+  }
 }