diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index ebd13bc248f9..82a7db32543a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -203,7 +203,7 @@ private[spark] object StreamingQueryProgress {
   }
 
   private[spark] def jsonString(progress: StreamingQueryProgress): String =
-    mapper.writeValueAsString(progress)
+    progress.json
 
   private[spark] def fromJson(json: String): StreamingQueryProgress =
     mapper.readValue[StreamingQueryProgress](json)
diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index c1c9dce04731..e168f017aca9 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -394,9 +394,11 @@ def errorClassOnException(self) -> Optional[str]:
         return self._errorClassOnException
 
 
-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
     """
     .. versionadded:: 3.4.0
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
 
     Notes
     -----
@@ -473,6 +475,10 @@ def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress":
 
     @classmethod
     def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
+        num_input_rows = j.get("numInputRows", None)
+        input_rows_per_sec = j.get("inputRowsPerSecond", None)
+        processed_rows_per_sec = j.get("processedRowsPerSecond", None)
+
         return cls(
             jdict=j,
             id=uuid.UUID(j["id"]),
@@ -486,9 +492,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"],
-            inputRowsPerSecond=j["inputRowsPerSecond"],
-            processedRowsPerSecond=j["processedRowsPerSecond"],
+            numInputRows=num_input_rows,
+            inputRowsPerSecond=input_rows_per_sec,
+            processedRowsPerSecond=processed_rows_per_sec,
             observedMetrics={
                 k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
                 for k, row_dict in j["observedMetrics"].items()
@@ -600,21 +606,30 @@ def numInputRows(self) -> int:
         """
         The aggregate (across all sources) number of records processed in a trigger.
         """
-        return self._numInputRows
+        if self._numInputRows is not None:
+            return self._numInputRows
+        else:
+            return sum(s.numInputRows for s in self.sources)
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
        The aggregate (across all sources) rate of data arriving.
         """
-        return self._inputRowsPerSecond
+        if self._inputRowsPerSecond is not None:
+            return self._inputRowsPerSecond
+        else:
+            return sum(s.inputRowsPerSecond for s in self.sources)
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate at which Spark is processing data.
         """
-        return self._processedRowsPerSecond
+        if self._processedRowsPerSecond is not None:
+            return self._processedRowsPerSecond
+        else:
+            return sum(s.processedRowsPerSecond for s in self.sources)
 
     @property
     def json(self) -> str:
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index 1799f0d1336e..b614d0670262 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -28,7 +28,7 @@ class StreamingTestsMixin:
     def test_streaming_query_functions_basic(self):
-        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
         query = (
             df.writeStream.format("memory")
             .queryName("test_streaming_query_functions_basic")
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 15f5575d3647..7d3d2c6f893f 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -194,6 +194,17 @@ def check_sink_progress(self, progress):
         self.assertTrue(isinstance(progress.numOutputRows, int))
         self.assertTrue(isinstance(progress.metrics, dict))
 
+    def test_streaming_last_progress(self):
+        try:
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            query = df.writeStream.format("noop").queryName("test_streaming_progress").start()
+            query.processAllAvailable()
+
+            progress = StreamingQueryProgress.fromJson(query.lastProgress)
+            self.check_streaming_query_progress(progress, False)
+        finally:
+            query.stop()
+
     # This is a generic test that works for both classic Spark and Spark Connect
     def test_listener_observed_metrics(self):
         class MyErrorListener(StreamingQueryListener):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 05323d9d0381..c81412b30102 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -187,7 +187,7 @@ private[spark] object StreamingQueryProgress {
   }
 
   private[spark] def jsonString(progress: StreamingQueryProgress): String =
-    mapper.writeValueAsString(progress)
+    progress.json
 
   private[spark] def fromJson(json: String): StreamingQueryProgress =
     mapper.readValue[StreamingQueryProgress](json)