Skip to content
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
bcdd070
add counter to monitor number of active SparkSessions
Jun 5, 2019
f333a30
clear session always in stop()
Jun 5, 2019
8fc95e9
added tests
Jun 6, 2019
ff60b84
Merge branch 'master' of github.com:apache/spark into vinooganesh/SPA…
Jun 10, 2019
f69cabb
responding to PR comments
Jun 10, 2019
3fafd2a
cleanup
Jun 10, 2019
0c9c426
adding ticket number to test
Jun 10, 2019
843491f
style
Jun 10, 2019
92c7b22
addressing sean's PR and updating test
Jun 14, 2019
d4e4e27
Merge branch 'master' into vinooganesh/SPARK-27958
Mar 21, 2020
61c6fed
first iteration of new proposal
Mar 27, 2020
b586bf9
Merge remote-tracking branch 'origin/master' into vinooganesh/SPARK-2…
Apr 5, 2020
387acb1
testing and style
Apr 5, 2020
99c5f64
style fix
Apr 5, 2020
7613046
Merge remote-tracking branch 'origin' into vinooganesh/SPARK-27958
vinooganesh Apr 30, 2020
7a70369
Merge branch 'master' into vinooganesh/SPARK-27958
vinooganesh May 5, 2020
5a1e0ab
remove unnecessary s
vinooganesh May 6, 2020
38dea00
Merge branch 'master' into vinooganesh/SPARK-27958
vinooganesh May 6, 2020
0c7e9df
remove lifecycle methods - context tracks listener
vinooganesh May 12, 2020
0284c79
remove unnecessary import
vinooganesh May 12, 2020
19f45da
Merge branch 'master' into vinooganesh/SPARK-27958
vinooganesh May 12, 2020
9573805
add ticket number to test
vinooganesh May 13, 2020
e877ee1
Merge branch 'master' into vinooganesh/SPARK-27958
vinooganesh May 18, 2020
e5563a7
move logic to listener
vinooganesh May 19, 2020
e5ef33a
adding test and cleanup
vinooganesh May 20, 2020
cfa1462
same class, don't need SparkSession
vinooganesh May 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 121 additions & 17 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
package org.apache.spark.sql

import java.io.Closeable
import java.util.UUID
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
Expand All @@ -31,6 +32,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SparkSession.defaultSession
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
Expand Down Expand Up @@ -87,6 +89,27 @@ class SparkSession private(
// The call site where this SparkSession was constructed.
private val creationSite: CallSite = Utils.getCallSite()

private val _sessionListener: SparkListener = new SparkListener {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
defaultSession.set(null)
}
}
// Used to manage state in the the `SparkSession` singleton
private[spark] val _sessionId: UUID = UUID.randomUUID
private[spark] val _terminated: AtomicBoolean = new AtomicBoolean(false)
sparkContext.addSparkListener(_sessionListener)

private[spark] def assertNotTerminated(): Unit = {
if (_terminated.get()) {
throw new IllegalStateException(
s"""Cannot call methods on a terminated SparkSession.
|This terminated SparkSession was created at:
|
|${creationSite.longForm}
""".stripMargin)
}
}

/**
* Constructor used in Pyspark. Contains explicit application of Spark Session Extensions
* which otherwise only occurs during getOrCreate. We cannot add this to the default constructor
Expand All @@ -108,6 +131,8 @@ class SparkSession private(
.getOrElse(SQLConf.getFallbackConf)
})

def removeListener(): Unit = sparkContext.removeSparkListener(_sessionListener)

/**
* The version of Spark on which this application is running.
*
Expand Down Expand Up @@ -276,14 +301,18 @@ class SparkSession private(
* @since 2.0.0
*/
@transient
lazy val emptyDataFrame: DataFrame = Dataset.ofRows(self, LocalRelation())
lazy val emptyDataFrame: DataFrame = {
assertNotTerminated()
Dataset.ofRows(self, LocalRelation())
}

/**
* Creates a new [[Dataset]] of type T containing zero elements.
*
* @return 2.0.0
*/
def emptyDataset[T: Encoder]: Dataset[T] = {
assertNotTerminated()
val encoder = implicitly[Encoder[T]]
new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder)
}
Expand All @@ -294,6 +323,7 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = withActive {
assertNotTerminated()
val encoder = Encoders.product[A]
Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder))
}
Expand All @@ -304,6 +334,7 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = withActive {
assertNotTerminated()
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
Expand Down Expand Up @@ -342,6 +373,7 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive {
assertNotTerminated()
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val encoder = RowEncoder(schema)
Expand All @@ -360,6 +392,7 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
assertNotTerminated()
createDataFrame(rowRDD.rdd, schema)
}

Expand All @@ -373,6 +406,7 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive {
assertNotTerminated()
Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
}

Expand All @@ -385,6 +419,7 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = withActive {
assertNotTerminated()
val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
Expand All @@ -403,6 +438,7 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
assertNotTerminated()
createDataFrame(rdd.rdd, beanClass)
}

Expand All @@ -414,6 +450,7 @@ class SparkSession private(
* @since 1.6.0
*/
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = withActive {
assertNotTerminated()
val attrSeq = getSchema(beanClass)
val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
Expand All @@ -425,6 +462,7 @@ class SparkSession private(
* @since 2.0.0
*/
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
assertNotTerminated()
Dataset.ofRows(self, LogicalRelation(baseRelation))
}

Expand Down Expand Up @@ -459,7 +497,9 @@ class SparkSession private(
*
* @since 2.0.0
*/

def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
assertNotTerminated()
val enc = encoderFor[T]
val toRow = enc.createSerializer()
val attributes = enc.schema.toAttributes
Expand All @@ -468,6 +508,7 @@ class SparkSession private(
Dataset[T](self, plan)
}


/**
* Creates a [[Dataset]] from an RDD of a given type. This method requires an
* encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation)
Expand All @@ -477,6 +518,7 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
assertNotTerminated()
Dataset[T](self, ExternalRDD(data, self))
}

Expand All @@ -496,6 +538,7 @@ class SparkSession private(
* @since 2.0.0
*/
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
assertNotTerminated()
createDataset(data.asScala)
}

Expand All @@ -505,7 +548,10 @@ class SparkSession private(
*
* @since 2.0.0
*/
def range(end: Long): Dataset[java.lang.Long] = range(0, end)
def range(end: Long): Dataset[java.lang.Long] = {
assertNotTerminated()
range(0, end)
}

/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements
Expand All @@ -514,6 +560,7 @@ class SparkSession private(
* @since 2.0.0
*/
def range(start: Long, end: Long): Dataset[java.lang.Long] = {
assertNotTerminated()
range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
}

Expand All @@ -524,6 +571,7 @@ class SparkSession private(
* @since 2.0.0
*/
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
assertNotTerminated()
range(start, end, step, numPartitions = sparkContext.defaultParallelism)
}

Expand All @@ -535,6 +583,7 @@ class SparkSession private(
* @since 2.0.0
*/
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
assertNotTerminated()
new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG)
}

Expand All @@ -545,6 +594,7 @@ class SparkSession private(
catalystRows: RDD[InternalRow],
schema: StructType,
isStreaming: Boolean = false): DataFrame = {
assertNotTerminated()
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(
Expand Down Expand Up @@ -578,14 +628,17 @@ class SparkSession private(
* @since 2.0.0
*/
def table(tableName: String): DataFrame = {
assertNotTerminated()
table(sessionState.sqlParser.parseMultipartIdentifier(tableName))
}

private[sql] def table(multipartIdentifier: Seq[String]): DataFrame = {
assertNotTerminated()
Dataset.ofRows(self, UnresolvedRelation(multipartIdentifier))
}

private[sql] def table(tableIdent: TableIdentifier): DataFrame = {
assertNotTerminated()
Dataset.ofRows(self, UnresolvedRelation(tableIdent))
}

Expand All @@ -600,6 +653,7 @@ class SparkSession private(
* @since 2.0.0
*/
def sql(sqlText: String): DataFrame = withActive {
assertNotTerminated()
val tracker = new QueryPlanningTracker
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
sessionState.sqlParser.parsePlan(sqlText)
Expand All @@ -624,6 +678,7 @@ class SparkSession private(
*/
@Unstable
def executeCommand(runner: String, command: String, options: Map[String, String]): DataFrame = {
assertNotTerminated()
DataSource.lookupDataSource(runner, sessionState.conf) match {
case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
Expand All @@ -644,7 +699,10 @@ class SparkSession private(
*
* @since 2.0.0
*/
def read: DataFrameReader = new DataFrameReader(self)
def read: DataFrameReader = {
assertNotTerminated()
new DataFrameReader(self)
}

/**
* Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
Expand All @@ -655,7 +713,10 @@ class SparkSession private(
*
* @since 2.0.0
*/
def readStream: DataStreamReader = new DataStreamReader(self)
def readStream: DataStreamReader = {
assertNotTerminated()
new DataStreamReader(self)
}

/**
* Executes some code block and prints to stdout the time taken to execute the block. This is
Expand Down Expand Up @@ -691,21 +752,56 @@ class SparkSession private(
}
// scalastyle:on

/**
* Lifecycle method that cleans up state of spark session, and mark session
* "ended" forever. This differs from `stop()` or `stopContext()` as it keeps
* the underlying `SparkContext` alive, while only getting
*/

def terminate(): Unit = {
// Session is still active
if (!_terminated.get()) {
sparkContext.removeSparkListener(this._sessionListener)
SparkSession.removeTerminatedSession(_sessionId)
_terminated.set(true)
}
else {
throw new IllegalStateException(
s"""Cannot call methods on a terminated SparkSession. Call
Comment thread
vinooganesh marked this conversation as resolved.
Outdated
|getOrCreate() to create a new session.
|""".stripMargin)
}
}

/**
* Stop the underlying `SparkContext`.
*
* @since 2.0.0
*/
def stop(): Unit = {
def stopContext(): Unit = {
if (!_terminated.get()) {
// stopping the context should also terminate this session
terminate()
SparkSession.removeTerminatedSession(_sessionId)
SparkSession.clearDefaultSession()
SparkSession.clearActiveSession()
}
sparkContext.stop()
}

/**
* Synonym for `stopContext()`.
*
* @since 2.0.0
*/
def stop(): Unit = stopContext()

/**
* Synonym for `stop()`.
*
* @since 2.1.0
*/
override def close(): Unit = stop()
override def close(): Unit = stopContext()

/**
* Parses the data type in our internal string representation. The data type string should
Expand Down Expand Up @@ -911,7 +1007,7 @@ object SparkSession extends Logging {

// Global synchronization so we will only set the default session once.
SparkSession.synchronized {
// If the current thread does not have an active session, get it from the global session.
// If the current thread does not have an active session, get it from the default session.
session = defaultSession.get()
if ((session ne null) && !session.sparkContext.isStopped) {
applyModifiableSettings(session)
Expand Down Expand Up @@ -940,15 +1036,6 @@ object SparkSession extends Logging {
options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
setDefaultSession(session)
setActiveSession(session)

// Register a successfully instantiated context to the singleton. This should be at the
// end of the class definition so that the singleton is updated only if there is no
// exception in the construction of the instance.
sparkContext.addSparkListener(new SparkListener {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
defaultSession.set(null)
}
})
}

return session
Expand Down Expand Up @@ -978,6 +1065,23 @@ object SparkSession extends Logging {
*/
def builder(): Builder = new Builder

def removeTerminatedSession(sessionId: UUID): Unit = {
// clean up active session
if (getActiveSession.isDefined) {
val activeSession = getActiveSession.get
if(sessionId.equals(activeSession._sessionId)) {
clearActiveSession()
}
}
// clean up default session
if (getDefaultSession.isDefined) {
val defaultSession = getDefaultSession.get
if(sessionId.equals(defaultSession._sessionId)) {
clearDefaultSession()
}
}
}

/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
Expand Down
Loading