diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index 4973bb5b6cf7..b5bb7f2a0912 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -446,6 +446,11 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter":  # type: ignore[misc]
     partitionBy.__doc__ = PySparkDataStreamWriter.partitionBy.__doc__
 
     def queryName(self, queryName: str) -> "DataStreamWriter":
+        if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
+            raise PySparkValueError(
+                error_class="VALUE_NOT_NON_EMPTY_STR",
+                message_parameters={"arg_name": "queryName", "arg_value": str(queryName)},
+            )
         self._write_proto.query_name = queryName
         return self
 
@@ -605,7 +610,9 @@ def _start_internal(
             session=self._session,
             queryId=start_result.query_id.id,
             runId=start_result.query_id.run_id,
-            name=start_result.name,
+            # A streaming query cannot have an empty string as its name;
+            # Spark throws an error in that case, so this conversion is safe.
+            name=start_result.name if start_result.name != "" else None,
         )
 
         if start_result.HasField("query_started_event_json"):
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index bcab8a104f1d..d3d58da3562b 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -114,7 +114,7 @@ def runId(self) -> str:
     @property
     def name(self) -> str:
         """
-        Returns the user-specified name of the query, or null if not specified.
+        Returns the user-specified name of the query, or None if not specified.
         This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter`
         as `dataframe.writeStream.queryName("query").start()`.
         This name, if set, must be unique across all active queries.
@@ -127,14 +127,14 @@ def name(self) -> str:
         Returns
         -------
         str
-            The user-specified name of the query, or null if not specified.
+            The user-specified name of the query, or None if not specified.
 
         Examples
        --------
         >>> sdf = spark.readStream.format("rate").load()
         >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
 
-        Get the user-specified name of the query, or null if not specified.
+        Get the user-specified name of the query, or None if not specified.
 
         >>> sq.name
         'this_query'
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index 1799f0d1336e..ea5ccb363088 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -24,6 +24,7 @@
 from pyspark.sql.functions import lit
 from pyspark.sql.types import StructType, StructField, IntegerType, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.errors import PySparkValueError
 
 
 class StreamingTestsMixin:
@@ -58,6 +59,26 @@ def test_streaming_query_functions_basic(self):
         finally:
             query.stop()
 
+    def test_streaming_query_name_edge_case(self):
+        # Query name should be None when not specified
+        q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
+        self.assertEqual(q1.name, None)
+
+        # Cannot set query name to be an empty string
+        error_thrown = False
+        try:
+            (
+                self.spark.readStream.format("rate")
+                .load()
+                .writeStream.format("noop")
+                .queryName("")
+                .start()
+            )
+        except PySparkValueError:
+            error_thrown = True
+
+        self.assertTrue(error_thrown)
+
     def test_stream_trigger(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
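
For reference, a minimal usage sketch of the behavior the patch enforces (not part of the diff itself). It assumes a local SparkSession bound to `spark` and reuses the built-in `rate` source and `noop` sink from the new test; run it against a build that includes this change:

    from pyspark.sql import SparkSession
    from pyspark.errors import PySparkValueError

    spark = SparkSession.builder.getOrCreate()

    # An unnamed query now reports None instead of an empty string.
    query = spark.readStream.format("rate").load().writeStream.format("noop").start()
    assert query.name is None
    query.stop()

    # An empty or blank query name is rejected before the query is started.
    try:
        spark.readStream.format("rate").load().writeStream.queryName("   ")
    except PySparkValueError:
        print("blank query name rejected")  # error class: VALUE_NOT_NON_EMPTY_STR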