[SPARK-24396] [SS] [PYSPARK] Add Structured Streaming ForeachWriter for python #21477
Changes from 4 commits
python/pyspark/sql/streaming.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| from pyspark.sql.readwriter import OptionUtils, to_str | ||
| from pyspark.sql.types import * | ||
| from pyspark.sql.utils import StreamingQueryException | ||
| from abc import ABCMeta, abstractmethod | ||
|
||
|
|
||
| __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] | ||
|
|
||
|
|
@@ -843,6 +844,169 @@ def trigger(self, processingTime=None, once=None, continuous=None): | |
| self._jwrite = self._jwrite.trigger(jTrigger) | ||
| return self | ||
|
|
||
| def foreach(self, f): | ||
| """ | ||
| Sets the output of the streaming query to be processed using the provided writer ``f``. | ||
| This is often used to write the output of a streaming query to arbitrary storage systems. | ||
| The processing logic can be specified in two ways. | ||
|
|
||
| #. A **function** that takes a row as input. | ||
|
Member:
This variant seems specific to Python. I thought we should better match what we support on the Scala side.

Contributor (Author):
This is a superset of what we support in Scala. Python users are more likely to use simple lambdas instead of defining classes, but they may also want to write transactional logic in Python with open and close methods. Hence providing both alternatives seems like a good idea.

Member:
(Including the response to #21477 (comment).) I kind of agree that it's an okay idea, but so far we have usually provided consistent API support unless something is language specific (for example, ContextManager or decorators in Python). Just for clarification, does the Scala side support the function-only form too? Also, I know the attribute-checking way is more "Pythonic", but I see the documentation is already diverging between Scala and Python, which adds maintenance overhead.

Member:
I mean, we could maybe consider the other ways, but wouldn't it be better to have the consistent support as the primary, and then see if the other ways are really requested by users? We could still incrementally add the attribute-checking way or the lambda (or, more correctly, function) way later, but we can't go in the opposite direction.

Contributor (Author):
Python APIs anyway have slight divergences from the Scala/Java APIs in order to provide a better experience for Python users. For example, … Personally, I think we should also add the lambda variant to Scala as well, but that decision for Scala is independent of this PR, as there is enough justification for adding the lambda variant for …

Member:
I believe …
Thing is, it sounded to me like we are kind of prejudging it; we can't easily revert once we go this way.
+1. I am okay, but I hope this isn't usually done next time.
||
| This is a simple way to express your processing logic. Note that this does | ||
| not allow you to deduplicate generated data when failures cause reprocessing of | ||
| some input data. That would require you to specify the processing logic using | ||
| the second approach below. | ||
|
|
||
| #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. | ||
|
Member:
@tdas, wouldn't we be better off just having …

Contributor (Author):
I discussed this with @marmbrus. If there were a ForeachWriter class in Python, users would have to additionally import it. That's just another overhead that can be avoided by allowing any class with the appropriate methods. One less step for Python users.
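A minimal sketch of that point (assuming an existing streaming DataFrame ``sdf``): the writer is a plain class, and nothing extra has to be imported from pyspark.

```python
# Any object exposing a process(row) method is accepted; no ForeachWriter
# base class needs to be imported or subclassed.
class PlainWriter(object):
    def process(self, row):
        print(row)

sdf.writeStream.foreach(PlainWriter())
```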
||
| The object can have the following methods. | ||
|
|
||
| * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing | ||
| (for example, open a connection, start a transaction, etc). Additionally, you can | ||
| use the `partition_id` and `epoch_id` to deduplicate regenerated data | ||
| (discussed later). | ||
|
|
||
| * ``process(row)``: *Non-optional* method that processes each :class:`Row`. | ||
|
|
||
| * ``close(error)``: *Optional* method that finalizes and cleans up (for example, | ||
| close connection, commit transaction, etc.) after all rows have been processed. | ||
|
|
||
| The object will be used by Spark in the following way. | ||
|
|
||
| * A single copy of this object is responsible for all the data generated by a | ||
| single task in a query. In other words, one instance is responsible for | ||
| processing one partition of the data generated in a distributed manner. | ||
|
|
||
| * This object must be serializable because each task will get a fresh | ||
| serialized-deserialized copy of the provided object. Hence, it is strongly | ||
|
||
| recommended that any initialization for writing data (e.g. opening a | ||
|
Contributor:
…

Contributor (Author):
done
||
| connection or starting a transaction) be done after the `open(...)` | ||
| method has been called, which signifies that the task is ready to generate data. | ||
|
|
||
| * The lifecycle of the methods is as follows. | ||
|
|
||
| For each partition with ``partition_id``: | ||
|
|
||
| ... For each batch/epoch of streaming data with ``epoch_id``: | ||
|
|
||
| ....... Method ``open(partition_id, epoch_id)`` is called. | ||
|
|
||
| ....... If ``open(...)`` returns true, for each row in the partition and | ||
| batch/epoch, method ``process(row)`` is called. | ||
|
|
||
| ....... Method ``close(error)`` is called with the error (if any) seen while | ||
| processing rows. | ||
|
|
||
| Important points to note: | ||
|
|
||
| * The ``partition_id`` and ``epoch_id`` can be used to deduplicate generated data | ||
| when failures cause reprocessing of some input data. This depends on the execution | ||
| mode of the query. If the streaming query is being executed in micro-batch | ||
| mode, then every partition represented by a unique tuple (partition_id, epoch_id) | ||
| is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used | ||
| to deduplicate and/or transactionally commit data and achieve exactly-once | ||
| guarantees. However, if the streaming query is being executed in continuous | ||
| mode, this guarantee does not hold, so (partition_id, epoch_id) should not be | ||
| used for deduplication. | ||
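As a rough illustration of the micro-batch guarantee above (not part of this PR), a writer could key its commits on ``(partition_id, epoch_id)``. The marker-file helpers below are hypothetical stand-ins for a real sink's transactional primitives:

```python
import os
import tempfile

# Illustrative only: a real sink would use its own transactional or
# idempotence primitives instead of local marker files.
COMMIT_DIR = os.path.join(tempfile.gettempdir(), "foreach_commit_markers")


class IdempotentWriter(object):
    def open(self, partition_id, epoch_id):
        if not os.path.isdir(COMMIT_DIR):
            os.makedirs(COMMIT_DIR)
        self.marker = os.path.join(COMMIT_DIR, "%d-%d" % (partition_id, epoch_id))
        self.rows = []
        # In micro-batch mode the same (partition_id, epoch_id) always carries
        # the same data, so an already-committed pair can be skipped entirely.
        return not os.path.exists(self.marker)

    def process(self, row):
        self.rows.append(row)

    def close(self, error):
        if error is None:
            # Stand-in for an atomic commit of self.rows keyed by the marker.
            with open(self.marker, "w") as f:
                f.write("%d rows\n" % len(self.rows))
```

Such a writer would be attached with ``sdf.writeStream.foreach(IdempotentWriter()).start()``.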
|
|
||
| The ``close()`` method (if it exists) will be called if the ``open()`` method | ||
| exists and returns successfully (irrespective of the return value), except if | ||
| the Python process crashes in the middle. | ||
|
|
||
| .. note:: Evolving. | ||
|
|
||
| >>> # Print every row using a function | ||
| >>> writer = sdf.writeStream.foreach(lambda x: print(x)) | ||
| >>> # Print every row using an object with a process() method | ||
| >>> class RowPrinter: | ||
| ... def open(self, partition_id, epoch_id): | ||
| ... print("Opened %d, %d" % (partition_id, epoch_id)) | ||
| ... return True | ||
| ... def process(self, row): | ||
| ... print(row) | ||
| ... def close(self, error): | ||
| ... print("Closed with error: %s" % str(error)) | ||
| ... | ||
| >>> writer = sdf.writeStream.foreach(RowPrinter()) | ||
| """ | ||
|
|
||
| from pyspark.rdd import _wrap_function | ||
| from pyspark.serializers import PickleSerializer, AutoBatchedSerializer | ||
| from pyspark.taskcontext import TaskContext | ||
|
|
||
| if callable(f): | ||
| """ | ||
| The provided object is a callable function that is supposed to be called on each row. | ||
| Construct a function that takes an iterator and calls the provided function on each row. | ||
| """ | ||
|
||
| def func_without_process(_, iterator): | ||
| for x in iterator: | ||
| f(x) | ||
| return iter([]) | ||
|
|
||
| func = func_without_process | ||
|
|
||
| else: | ||
| """ | ||
| The provided object is not a callable function. Then it is expected to have a | ||
| 'process(row)' method, and optional 'open(partition_id, epoch_id)' and | ||
| 'close(error)' methods. | ||
| """ | ||
|
||
|
|
||
| if not hasattr(f, 'process'): | ||
| raise Exception( | ||
| "Provided object is neither callable nor does it have a 'process' method") | ||
|
|
||
| if not callable(getattr(f, 'process')): | ||
| raise Exception("Attribute 'process' in provided object is not callable") | ||
|
|
||
| open_exists = False | ||
|
||
| if hasattr(f, 'open'): | ||
| if not callable(getattr(f, 'open')): | ||
| raise Exception("Attribute 'open' in provided object is not callable") | ||
| else: | ||
| open_exists = True | ||
|
|
||
| close_exists = False | ||
|
||
| if hasattr(f, "close"): | ||
|
||
| if not callable(getattr(f, 'close')): | ||
| raise Exception("Attribute 'close' in provided object is not callable") | ||
| else: | ||
| close_exists = True | ||
|
|
||
| def func_with_open_process_close(partition_id, iterator): | ||
| epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') | ||
| if epoch_id: | ||
| epoch_id = int(epoch_id) | ||
| else: | ||
| raise Exception("Could not get batch id from TaskContext") | ||
|
|
||
| should_process = True | ||
| if open_exists: | ||
| should_process = f.open(partition_id, epoch_id) | ||
|
|
||
| def call_close_if_needed(error): | ||
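| # close() is only invoked when both an open() and a close() method were | ||
| # provided, matching the contract documented in the docstring above. | ||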
| if open_exists and close_exists: | ||
| f.close(error) | ||
| try: | ||
| if should_process: | ||
| for x in iterator: | ||
| f.process(x) | ||
| except Exception as ex: | ||
| call_close_if_needed(ex) | ||
|
||
| raise ex | ||
|
|
||
| call_close_if_needed(None) | ||
| return iter([]) | ||
|
|
||
| func = func_with_open_process_close | ||
|
|
||
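| # Pickle the chosen partition function and hand it to the JVM-side | ||
| # PythonForeachWriter, which runs it over rows with this schema. | ||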
| serializer = AutoBatchedSerializer(PickleSerializer()) | ||
| wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) | ||
| jForeachWriter = \ | ||
| self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( | ||
| wrapped_func, self._df._jdf.schema()) | ||
| self._jwrite.foreach(jForeachWriter) | ||
| return self | ||
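For context, a hedged end-to-end use of the new sink might look like the following; the socket source, host, and port are purely illustrative, and ``spark`` is assumed to be an existing SparkSession.

```python
# Hedged usage sketch: read lines from a socket source and print every row
# through the new foreach() sink.
lines = spark.readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

query = lines.writeStream.foreach(lambda row: print(row)).start()
query.awaitTermination()
```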
|
|
||
| @ignore_unicode_prefix | ||
| @since(2.0) | ||
| def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, | ||
|
|
||
python/pyspark/sql/tests.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -296,6 +296,7 @@ def tearDown(self): | |
| # tear down test_bucketed_write state | ||
| self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket") | ||
|
|
||
| ''' | ||
|
||
| def test_row_should_be_read_only(self): | ||
| row = Row(a=1, b=2) | ||
| self.assertEqual(1, row.a) | ||
|
|
@@ -1884,7 +1885,164 @@ def test_query_manager_await_termination(self): | |
| finally: | ||
| q.stop() | ||
| shutil.rmtree(tmpPath) | ||
| ''' | ||
|
|
||
| class ForeachWriterTester: | ||
|
|
||
| def __init__(self, spark): | ||
| self.spark = spark | ||
| self.input_dir = tempfile.mkdtemp() | ||
| self.open_events_dir = tempfile.mkdtemp() | ||
| self.process_events_dir = tempfile.mkdtemp() | ||
| self.close_events_dir = tempfile.mkdtemp() | ||
|
|
||
| def write_open_event(self, partitionId, epochId): | ||
| self._write_event( | ||
| self.open_events_dir, | ||
| {'partition': partitionId, 'epoch': epochId}) | ||
|
|
||
| def write_process_event(self, row): | ||
| self._write_event(self.process_events_dir, {'value': 'text'}) | ||
|
|
||
| def write_close_event(self, error): | ||
| self._write_event(self.close_events_dir, {'error': str(error)}) | ||
|
|
||
| def write_input_file(self): | ||
| self._write_event(self.input_dir, "text") | ||
|
|
||
| def open_events(self): | ||
| return self._read_events(self.open_events_dir, 'partition INT, epoch INT') | ||
|
|
||
| def process_events(self): | ||
| return self._read_events(self.process_events_dir, 'value STRING') | ||
|
|
||
| def close_events(self): | ||
| return self._read_events(self.close_events_dir, 'error STRING') | ||
|
|
||
| def run_streaming_query_on_writer(self, writer, num_files): | ||
| try: | ||
| sdf = self.spark.readStream.format('text').load(self.input_dir) | ||
| sq = sdf.writeStream.foreach(writer).start() | ||
| for i in range(num_files): | ||
| self.write_input_file() | ||
| sq.processAllAvailable() | ||
| sq.stop() | ||
| finally: | ||
| self.stop_all() | ||
|
|
||
| def _read_events(self, dir, json): | ||
| rows = self.spark.read.schema(json).json(dir).collect() | ||
| dicts = [row.asDict() for row in rows] | ||
| return dicts | ||
|
|
||
| def _write_event(self, dir, event): | ||
| import random | ||
| file = open(os.path.join(dir, str(random.randint(0, 100000))), 'w') | ||
|
||
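| # Dict events are written via str(), producing single-quoted JSON-like text; | ||
| # Spark's JSON reader still parses it since allowSingleQuotes defaults to true. | ||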
| file.write("%s\n" % str(event)) | ||
| file.close() | ||
|
||
|
|
||
| def stop_all(self): | ||
| for q in self.spark._wrapped.streams.active: | ||
| q.stop() | ||
|
|
||
| def __getstate__(self): | ||
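| # Pickle only the event directories; the SparkSession held in self.spark | ||
| # is not picklable and is not needed by the writer methods on executors. | ||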
| return (self.open_events_dir, self.process_events_dir, self.close_events_dir) | ||
|
|
||
| def __setstate__(self, state): | ||
| self.open_events_dir, self.process_events_dir, self.close_events_dir = state | ||
|
|
||
| def test_streaming_foreach_with_simple_function(self): | ||
| tester = self.ForeachWriterTester(self.spark) | ||
|
|
||
| def foreach_func(row): | ||
| tester.write_process_event(row) | ||
|
|
||
| tester.run_streaming_query_on_writer(foreach_func, 2) | ||
| self.assertEqual(len(tester.process_events()), 2) | ||
|
|
||
| def test_streaming_foreach_with_basic_open_process_close(self): | ||
| tester = self.ForeachWriterTester(self.spark) | ||
|
|
||
| class ForeachWriter: | ||
| def open(self, partitionId, epochId): | ||
| tester.write_open_event(partitionId, epochId) | ||
| return True | ||
|
|
||
| def process(self, row): | ||
| tester.write_process_event(row) | ||
|
|
||
| def close(self, error): | ||
| tester.write_close_event(error) | ||
|
|
||
| tester.run_streaming_query_on_writer(ForeachWriter(), 2) | ||
|
|
||
| open_events = tester.open_events() | ||
| self.assertEqual(len(open_events), 2) | ||
| self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) | ||
|
|
||
| self.assertEqual(len(tester.process_events()), 2) | ||
|
|
||
| close_events = tester.close_events() | ||
| self.assertEqual(len(close_events), 2) | ||
| self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) | ||
|
|
||
| def test_streaming_foreach_with_open_returning_false(self): | ||
| tester = self.ForeachWriterTester(self.spark) | ||
|
|
||
| class ForeachWriter: | ||
| def open(self, partitionId, epochId): | ||
| tester.write_open_event(partitionId, epochId) | ||
| return False | ||
|
|
||
| def process(self, row): | ||
| tester.write_process_event(row) | ||
|
|
||
| def close(self, error): | ||
| tester.write_close_event(error) | ||
|
|
||
| tester.run_streaming_query_on_writer(ForeachWriter(), 2) | ||
|
|
||
| self.assertEqual(len(tester.open_events()), 2) | ||
| self.assertEqual(len(tester.process_events()), 0) # no row was processed | ||
|
||
| close_events = tester.close_events() | ||
| self.assertEqual(len(close_events), 2) | ||
| self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) | ||
|
|
||
| def test_streaming_foreach_with_process_throwing_error(self): | ||
| from pyspark.sql.utils import StreamingQueryException | ||
|
|
||
| tester = self.ForeachWriterTester(self.spark) | ||
|
|
||
| class ForeachWriter: | ||
| def open(self, partitionId, epochId): | ||
| tester.write_open_event(partitionId, epochId) | ||
| return True | ||
|
|
||
| def process(self, row): | ||
| raise Exception("test error") | ||
|
|
||
| def close(self, error): | ||
| tester.write_close_event(error) | ||
|
|
||
| try: | ||
| sdf = self.spark.readStream.format('text').load(tester.input_dir) | ||
| sq = sdf.writeStream.foreach(ForeachWriter()).start() | ||
| tester.write_input_file() | ||
| sq.processAllAvailable() | ||
| self.fail("bad writer should fail the query") # this is not expected | ||
| except StreamingQueryException as e: | ||
| # self.assertTrue("test error" in e.desc) # this is expected | ||
|
||
| pass | ||
| finally: | ||
| tester.stop_all() | ||
|
|
||
| self.assertEqual(len(tester.open_events()), 1) | ||
| self.assertEqual(len(tester.process_events()), 0) # no row was processed | ||
| close_events = tester.close_events() | ||
| self.assertEqual(len(close_events), 1) | ||
| # self.assertTrue("test error" in e[0]['error']) | ||
|
||
|
|
||
| ''' | ||
|
||
| def test_help_command(self): | ||
| # Regression test for SPARK-5464 | ||
| rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) | ||
|
|
@@ -5391,7 +5549,7 @@ def test_invalid_args(self): | |
| AnalysisException, | ||
| 'mixture.*aggregate function.*group aggregate pandas UDF'): | ||
| df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() | ||
|
|
||
| ''' | ||
| if __name__ == "__main__": | ||
| from pyspark.sql.tests import * | ||
| if xmlrunner: | ||
|
|
||
Contributor (Author):
This is only to speed up local testing. Will remove this.