4 changes: 2 additions & 2 deletions python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -52,6 +52,8 @@ See also :class:`SparkSession`.
     SparkSession.dataSource
     SparkSession.getActiveSession
     SparkSession.getTags
+    SparkSession.interruptAll
+    SparkSession.interruptTag
     SparkSession.newSession
     SparkSession.profile
     SparkSession.removeTag
@@ -86,8 +88,6 @@ Spark Connect Only
     SparkSession.clearProgressHandlers
     SparkSession.client
     SparkSession.copyFromLocalToFs
-    SparkSession.interruptAll
     SparkSession.interruptOperation
-    SparkSession.interruptTag
     SparkSession.registerProgressHandler
     SparkSession.removeProgressHandler
34 changes: 24 additions & 10 deletions python/pyspark/sql/session.py
@@ -2197,13 +2197,15 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
)

@remote_only
def interruptAll(self) -> List[str]:
"""
Interrupt all operations of this session currently running on the connected server.

.. versionadded:: 3.5.0

.. versionchanged:: 4.0.0
Supports Spark Classic.

Returns
-------
list of str
@@ -2213,18 +2215,25 @@ def interruptAll(self) -> List[str]:
         -----
         There is still a possibility of operation finishing just as it is interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptAll"},
-        )
+        java_list = self._jsparkSession.interruptAll()
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
-    @remote_only
     def interruptTag(self, tag: str) -> List[str]:
         """
         Interrupt all operations of this session with the given operation tag.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2234,10 +2243,15 @@ def interruptTag(self, tag: str) -> List[str]:
         -----
         There is still a possibility of operation finishing just as it is interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptTag"},
-        )
+        java_list = self._jsparkSession.interruptTag(tag)
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
     @remote_only
     def interruptOperation(self, op_id: str) -> List[str]:
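
Note: the classic code path delegates to the JVM session and converts the returned Java list to Python strings via Py4J. A minimal end-to-end sketch of the new behavior (assumptions: a local classic session and a deliberately slow job; the returned list may be empty if the job finishes before the interrupt lands):

    import threading
    import time

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    def slow_job():
        try:
            # A long-running action that gives interruptAll something to cancel.
            spark.range(10**10).selectExpr("sum(id)").collect()
        except Exception:
            pass  # expected: the job is interrupted mid-flight

    t = threading.Thread(target=slow_job)
    t.start()
    time.sleep(2)  # give the job time to start

    interrupted = spark.interruptAll()  # no longer Connect-only
    print(interrupted)  # IDs of interrupted operations, as strings
    t.join()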
22 changes: 0 additions & 22 deletions python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
@@ -32,28 +32,6 @@ def func(target):
            create_thread=lambda target, session: threading.Thread(target=func, args=(target,))
         )
 
-    def test_interrupt_tag(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: self.spark.addTag(job_group),
-            lambda job_group: self.spark.interruptTag(job_group),
-            thread_ids,
-            [i for i in thread_ids if i % 2 == 0],
-            [i for i in thread_ids if i % 2 != 0],
-        )
-        self.spark.clearTags()
-
-    def test_interrupt_all(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: None,
-            lambda job_group: self.spark.interruptAll(),
-            thread_ids,
-            thread_ids,
-            [],
-        )
-        self.spark.clearTags()
-
 
 if __name__ == "__main__":
     import unittest
2 changes: 0 additions & 2 deletions python/pyspark/sql/tests/test_connect_compatibility.py
@@ -266,9 +266,7 @@ def test_spark_session_compatibility(self):
"addArtifacts",
"clearProgressHandlers",
"copyFromLocalToFs",
"interruptAll",
"interruptOperation",
"interruptTag",
"newSession",
"registerProgressHandler",
"removeProgressHandler",
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests/test_job_cancellation.py
@@ -166,6 +166,28 @@ def get_outer_local_prop():
         self.assertEqual(first, {"a", "b"})
         self.assertEqual(second, {"a", "b", "c"})
 
+    def test_interrupt_tag(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: self.spark.addTag(job_group),
+            lambda job_group: self.spark.interruptTag(job_group),
+            thread_ids,
+            [i for i in thread_ids if i % 2 == 0],
+            [i for i in thread_ids if i % 2 != 0],
+        )
+        self.spark.clearTags()
+
+    def test_interrupt_all(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: None,
+            lambda job_group: self.spark.interruptAll(),
+            thread_ids,
+            thread_ids,
+            [],
+        )
+        self.spark.clearTags()
+
 
 class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase):
     pass
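
Note: these tests, moved from the Connect-only parity suite into the shared suite, check that interruptTag cancels only operations whose threads added the matching tag. A standalone sketch of the same idea outside the test harness (assumptions: a local classic session with the default pinned-thread mode, so tags are scoped per Python thread; the tag name and sleep are illustrative):

    import threading
    import time

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    def tagged_job():
        spark.addTag("cancel-me")  # tag applies to operations from this thread
        try:
            spark.range(10**10).count()
        except Exception:
            pass  # expected: this tagged job is interrupted
        finally:
            spark.clearTags()

    t = threading.Thread(target=tagged_job)
    t.start()
    time.sleep(2)  # let the tagged job start

    print(spark.interruptTag("cancel-me"))  # cancels only the tagged operation
    t.join()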
1 change: 0 additions & 1 deletion python/pyspark/sql/tests/test_session.py
@@ -227,7 +227,6 @@ def test_unsupported_api(self):
             (lambda: session.client, "client"),
             (session.addArtifacts, "addArtifact(s)"),
             (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
-            (lambda: session.interruptTag(""), "interruptTag"),
             (lambda: session.interruptOperation(""), "interruptOperation"),
         ]
 