@@ -203,7 +203,7 @@ private[spark] object StreamingQueryProgress {
   }
 
   private[spark] def jsonString(progress: StreamingQueryProgress): String =
-    mapper.writeValueAsString(progress)
+    progress.json
 
   private[spark] def fromJson(json: String): StreamingQueryProgress =
     mapper.readValue[StreamingQueryProgress](json)
29 changes: 22 additions & 7 deletions python/pyspark/sql/streaming/listener.py
@@ -394,9 +394,11 @@ def errorClassOnException(self) -> Optional[str]:
         return self._errorClassOnException
 
 
-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
     """
     .. versionadded:: 3.4.0
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
 
     Notes
     -----
@@ -473,6 +475,10 @@ def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress":
 
     @classmethod
     def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
+        num_input_rows = j.get("numInputRows", None)
+        input_rows_per_sec = j.get("inputRowsPerSecond", None)
+        processed_rows_per_sec = j.get("processedRowsPerSecond", None)
+
         return cls(
             jdict=j,
             id=uuid.UUID(j["id"]),
@@ -486,9 +492,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"],
-            inputRowsPerSecond=j["inputRowsPerSecond"],
-            processedRowsPerSecond=j["processedRowsPerSecond"],
+            numInputRows=j["numInputRows"] if "numInputRows" in j else None,
+            inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
+            processedRowsPerSecond=j["processedRowsPerSecond"] if "processedRowsPerSecond" in j else None,
             observedMetrics={
                 k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
                 for k, row_dict in j["observedMetrics"].items()
@@ -600,21 +606,30 @@ def numInputRows(self) -> int:
         """
         The aggregate (across all sources) number of records processed in a trigger.
         """
-        return self._numInputRows
+        if self._numInputRows is not None:
+            return self._numInputRows
+        else:
+            return sum(s.numInputRows for s in self.sources)
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate of data arriving.
         """
-        return self._inputRowsPerSecond
+        if self._inputRowsPerSecond is not None:
+            return self._inputRowsPerSecond
+        else:
+            return sum(s.inputRowsPerSecond for s in self.sources)
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate at which Spark is processing data.
         """
-        return self._processedRowsPerSecond
+        if self._processedRowsPerSecond is not None:
+            return self._processedRowsPerSecond
+        else:
+            return sum(s.processedRowsPerSecond for s in self.sources)
 
     @property
     def json(self) -> str:
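As a rough illustration of what the listener.py changes above are meant to enable (assuming a running StreamingQuery bound to a variable named query; anything not shown in the diff is hypothetical):

    from pyspark.sql.streaming.listener import StreamingQueryProgress

    # In classic PySpark, query.lastProgress is a plain dict; with this change,
    # fromJson can accept a dict that omits the aggregate fields instead of raising KeyError.
    progress = StreamingQueryProgress.fromJson(query.lastProgress)

    progress["batchId"]    # dict-style lookup, which subclassing dict is intended to allow
    progress.numInputRows  # when the field was absent, falls back to summing s.numInputRows over progress.sources
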
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/streaming/test_streaming.py
@@ -28,7 +28,7 @@
 
 class StreamingTestsMixin:
     def test_streaming_query_functions_basic(self):
-        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
Contributor Author:

This is just a minor fix; it makes the test more stable. We call query.processAllAvailable() below, which has a very small chance of being an indefinite blocking call on a rate source.

         query = (
             df.writeStream.format("memory")
             .queryName("test_streaming_query_functions_basic")
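A minimal sketch of the pattern that comment describes, assuming an existing SparkSession named spark (the query name and the print are illustrative): a finite text-directory source drains and lets processAllAvailable() return, whereas an unbounded rate source can keep it waiting.

    df = spark.readStream.format("text").load("python/test_support/sql/streaming")
    query = df.writeStream.format("memory").queryName("stable_example").start()
    try:
        query.processAllAvailable()  # returns once the finite text input has been processed
        print(query.lastProgress["numInputRows"])
    finally:
        query.stop()
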
11 changes: 11 additions & 0 deletions python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -194,6 +194,17 @@ def check_sink_progress(self, progress):
         self.assertTrue(isinstance(progress.numOutputRows, int))
         self.assertTrue(isinstance(progress.metrics, dict))
 
+    def test_streaming_last_progress(self):
+        try:
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            query = df.writeStream.format("noop").queryName("test_streaming_progress").start()
+            query.processAllAvailable()
+
+            progress = StreamingQueryProgress.fromJson(query.lastProgress)
+            self.check_streaming_query_progress(progress, False)
+        finally:
+            query.stop()
+
     # This is a generic test work for both classic Spark and Spark Connect
     def test_listener_observed_metrics(self):
         class MyErrorListener(StreamingQueryListener):
@@ -187,7 +187,7 @@ private[spark] object StreamingQueryProgress {
   }
 
   private[spark] def jsonString(progress: StreamingQueryProgress): String =
-    mapper.writeValueAsString(progress)
+    progress.json
 
   private[spark] def fromJson(json: String): StreamingQueryProgress =
     mapper.readValue[StreamingQueryProgress](json)