[SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame #20678
Changes from 4 commits
@@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da
 `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set
 the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default.

+In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' will fallback automatically
+to non-optimized implementations if an error occurs. This can be controlled by
+'spark.sql.execution.arrow.fallback.enabled'.
+
 <div class="codetabs">
 <div data-lang="python" markdown="1">
 {% include_example dataframe_with_arrow python/sql/arrow.py %}

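As a quick illustration of the documented behavior (not part of this diff), here is a minimal PySpark sketch along the lines of the included `arrow.py` example; it assumes an existing `SparkSession` named `spark`:

```python
import numpy as np
import pandas as pd

# Enable Arrow-based conversions; the fallback conf (default: true) controls
# whether errors before computation fall back to the non-Arrow code path.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")

pdf = pd.DataFrame(np.random.rand(100, 3))   # a plain Pandas DataFrame
df = spark.createDataFrame(pdf)              # Pandas -> Spark, Arrow-optimized
result_pdf = df.select("*").toPandas()       # Spark -> Pandas, Arrow-optimized
```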
@@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
 ## Upgrading From Spark SQL 2.3 to 2.4

 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively.
+- In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched by `spark.sql.execution.arrow.fallback.enabled`.

 ## Upgrading From Spark SQL 2.2 to 2.3

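For users who prefer the stricter pre-2.4 behavior described in the new migration note above, a hedged sketch of opting out (again assuming an existing `SparkSession` named `spark`):

```python
# Fail fast instead of silently falling back when Arrow optimization
# cannot be used for toPandas() / createDataFrame(pandas_df).
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
```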
@@ -1986,55 +1986,89 @@ def toPandas(self):
             timezone = None

         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
+            should_fallback = False
             try:
-                from pyspark.sql.types import _check_dataframe_convert_date, \
-                    _check_dataframe_localize_timestamps, to_arrow_schema
+                from pyspark.sql.types import to_arrow_schema
                 from pyspark.sql.utils import require_minimum_pyarrow_version

                 require_minimum_pyarrow_version()
-                import pyarrow
                 to_arrow_schema(self.schema)
-                tables = self._collectAsArrow()
-                if tables:
-                    table = pyarrow.concat_tables(tables)
-                    pdf = table.to_pandas()
-                    pdf = _check_dataframe_convert_date(pdf, self.schema)
-                    return _check_dataframe_localize_timestamps(pdf, timezone)
-                else:
-                    return pd.DataFrame.from_records([], columns=self.columns)
             except Exception as e:
-                msg = (
-                    "Note: toPandas attempted Arrow optimization because "
-                    "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false "
-                    "to disable this.")
-                raise RuntimeError("%s\n%s" % (_exception_message(e), msg))
-        else:
-            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
-
-            dtype = {}
+                if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \
+                        .lower() == "true":
+                    msg = (
+                        "toPandas attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                        "failed by the reason below:\n  %s\n"
+                        "Attempts non-optimization as "
+                        "'spark.sql.execution.arrow.fallback.enabled' is set to "
+                        "true." % _exception_message(e))
+                    warnings.warn(msg)
+                    should_fallback = True
+                else:
+                    msg = (
+                        "toPandas attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                        "failed by the reason below:\n  %s\n"

[Review comment from the author on the message above: "Hm ... I tried to like make a"]

+                        "For fallback to non-optimization automatically, please set true to "
+                        "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
+                    raise RuntimeError(msg)
+
+            if not should_fallback:
+                try:
+                    from pyspark.sql.types import _check_dataframe_convert_date, \
+                        _check_dataframe_localize_timestamps
+                    import pyarrow
+
+                    tables = self._collectAsArrow()
+                    if tables:
+                        table = pyarrow.concat_tables(tables)
+                        pdf = table.to_pandas()
+                        pdf = _check_dataframe_convert_date(pdf, self.schema)
+                        return _check_dataframe_localize_timestamps(pdf, timezone)
+                    else:
+                        return pd.DataFrame.from_records([], columns=self.columns)
+                except Exception as e:
+                    # We might have to allow fallback here as well but multiple Spark jobs can
+                    # be executed. So, simply fail in this case for now.
+                    msg = (
+                        "toPandas attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.enabled' is set to true; however, "
+                        "failed unexpectedly:\n  %s\n"
+                        "Note that 'spark.sql.execution.arrow.fallback.enabled' does "

[Review comment from a contributor: "+1 good job having this explanation in the exception"]

+                        "not have an effect in such failure in the middle of "
+                        "computation." % _exception_message(e))
+                    raise RuntimeError(msg)
+
+        # Below is toPandas without Arrow optimization.
+        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        dtype = {}
-            for field in self.schema:
-                pandas_type = _to_corrected_pandas_type(field.dataType)
-                # SPARK-21766: if an integer field is nullable and has null values, it can be
-                # inferred by pandas as float column. Once we convert the column with NaN back
-                # to integer type e.g., np.int16, we will hit exception. So we use the inferred
-                # float type, not the corrected type from the schema in this case.
-                if pandas_type is not None and \
-                    not(isinstance(field.dataType, IntegralType) and field.nullable and
-                        pdf[field.name].isnull().any()):
-                    dtype[field.name] = pandas_type
-
-            for f, t in dtype.items():
-                pdf[f] = pdf[f].astype(t, copy=False)
-
-            if timezone is None:
-                return pdf
-            else:
-                from pyspark.sql.types import _check_series_convert_timestamps_local_tz
-                for field in self.schema:
-                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-                    if isinstance(field.dataType, TimestampType):
-                        pdf[field.name] = \
-                            _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
-                return pdf
+        for field in self.schema:
+            pandas_type = _to_corrected_pandas_type(field.dataType)
+            # SPARK-21766: if an integer field is nullable and has null values, it can be
+            # inferred by pandas as float column. Once we convert the column with NaN back
+            # to integer type e.g., np.int16, we will hit exception. So we use the inferred
+            # float type, not the corrected type from the schema in this case.
+            if pandas_type is not None and \
+                not(isinstance(field.dataType, IntegralType) and field.nullable and
+                    pdf[field.name].isnull().any()):
+                dtype[field.name] = pandas_type
+
+        for f, t in dtype.items():
+            pdf[f] = pdf[f].astype(t, copy=False)
+
+        if timezone is None:
+            return pdf
+        else:
+            from pyspark.sql.types import _check_series_convert_timestamps_local_tz
+            for field in self.schema:
+                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+                if isinstance(field.dataType, TimestampType):
+                    pdf[field.name] = \
+                        _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
+            return pdf

     def _collectAsArrow(self):
         """

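To make the two code paths above concrete, here is a hedged user-level sketch (not part of the diff). It assumes a `SparkSession` named `spark` with Arrow enabled, and that `MapType` is still unsupported by the Arrow path as in this PR's tests; the schema check fails before any Spark job runs, so the fallback conf decides between a warning and a `RuntimeError`:

```python
import warnings
from pyspark.sql.types import MapType, StringType, IntegerType, StructType, StructField

schema = StructType([StructField("m", MapType(StringType(), IntegerType()), True)])
df = spark.createDataFrame([({u'a': 1},)], schema=schema)

spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pdf = df.toPandas()            # succeeds via the non-Arrow path
    print(caught[-1].message)      # warning mentions "Attempts non-optimization"

spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
try:
    df.toPandas()                  # now raises instead of falling back
except RuntimeError as e:
    print(e)                       # message points at the fallback conf
```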
@@ -32,7 +32,9 @@
 import datetime
 import array
 import ctypes
+import warnings
 import py4j
+from contextlib import contextmanager

 try:
     import xmlrunner

@@ -48,12 +50,13 @@
 else:
     import unittest

+from pyspark.util import _exception_message
+
 _pandas_requirement_message = None
 try:
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 except ImportError as e:
-    from pyspark.util import _exception_message
     # If Pandas version requirement is not satisfied, skip related tests.
     _pandas_requirement_message = _exception_message(e)

@@ -62,7 +65,6 @@
     from pyspark.sql.utils import require_minimum_pyarrow_version
     require_minimum_pyarrow_version()
 except ImportError as e:
-    from pyspark.util import _exception_message
     # If Arrow version requirement is not satisfied, skip related tests.
     _pyarrow_requirement_message = _exception_message(e)

@@ -195,6 +197,23 @@ def tearDownClass(cls):
         ReusedPySparkTestCase.tearDownClass()
         cls.spark.stop()

+    @contextmanager
+    def sql_conf(self, key, value):
+        """
+        A convenient context manager to test some configuration specific logic. This sets
+        `value` to the configuration `key` and then restores it back when it exits.
+        """
+
+        orig_value = self.spark.conf.get(key, None)
+        self.spark.conf.set(key, value)
+        try:
+            yield
+        finally:
+            if orig_value is None:
+                self.spark.conf.unset(key)
+            else:
+                self.spark.conf.set(key, orig_value)
+
     def assertPandasEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
                "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +

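A small usage sketch for the new `sql_conf` helper (a hypothetical test method, not part of the diff), showing that the previous value, or the absence of one, is restored when the block exits:

```python
# Hypothetical test inside a ReusedSQLTestCase-style class that defines sql_conf.
def test_conf_is_restored(self):
    original = self.spark.conf.get("spark.sql.execution.arrow.fallback.enabled", None)
    with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
        # Inside the block the temporary value is in effect.
        self.assertEqual(
            self.spark.conf.get("spark.sql.execution.arrow.fallback.enabled"), "false")
    # Outside the block the original value (or unset state) is back.
    self.assertEqual(
        self.spark.conf.get("spark.sql.execution.arrow.fallback.enabled", None), original)
```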
@@ -3458,6 +3477,8 @@ def setUpClass(cls):
         cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        # Disable fallback by default to easily detect the failures.
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
         cls.schema = StructType([
             StructField("1_str_t", StringType(), True),
             StructField("2_int_t", IntegerType(), True),

@@ -3493,19 +3514,30 @@ def create_pandas_data_frame(self):
         data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
         return pd.DataFrame(data=data_dict)

-    def test_unsupported_datatype(self):
-        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
-        df = self.spark.createDataFrame([(None,)], schema=schema)
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported type'):
-                df.toPandas()
-
-        df = self.spark.createDataFrame([(None,)], schema="a binary")
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'):
-                df.toPandas()
+    def test_toPandas_fallback_enabled(self):
+        import pandas as pd
+
+        with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", True):
+            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+            with QuietTest(self.sc):
+                with warnings.catch_warnings(record=True) as warns:
+                    pdf = df.toPandas()
+                    # Catch and check the last UserWarning.
+                    user_warns = [
+                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                    self.assertTrue(len(user_warns) > 0)
+                    self.assertTrue(
+                        "Attempts non-optimization" in _exception_message(user_warns[-1]))
+                    self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+
+    def test_toPandas_fallback_disabled(self):
+        with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
+            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+            df = self.spark.createDataFrame([(None,)], schema=schema)
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+                    df.toPandas()

     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +

@@ -3625,7 +3657,7 @@ def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
         wrong_schema = StructType(list(reversed(self.schema)))
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+            with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"):
                 self.spark.createDataFrame(pdf, schema=wrong_schema)

     def test_createDataFrame_with_names(self):

@@ -3650,7 +3682,7 @@ def test_createDataFrame_column_name_encoding(self):
     def test_createDataFrame_with_single_data_type(self):
         import pandas as pd
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+            with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
                 self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

     def test_createDataFrame_does_not_modify_input(self):

@@ -3705,6 +3737,31 @@ def test_createDataFrame_with_int_col_names(self):
         self.assertEqual(pdf_col_names, df.columns)
         self.assertEqual(pdf_col_names, df_arrow.columns)

+    def test_createDataFrame_fallback_enabled(self):
+        import pandas as pd
+
+        with QuietTest(self.sc):
+            with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", True):
+                with warnings.catch_warnings(record=True) as warns:
+                    df = self.spark.createDataFrame(
+                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+                    # Catch and check the last UserWarning.
+                    user_warns = [
+                        warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                    self.assertTrue(len(user_warns) > 0)
+                    self.assertTrue(
+                        "Attempts non-optimization" in _exception_message(user_warns[-1]))
+                    self.assertEqual(df.collect(), [Row(a={u'a': 1})])
+
+    def test_createDataFrame_fallback_disabled(self):
+        import pandas as pd
+
+        with QuietTest(self.sc):
+            with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
+                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+                    self.spark.createDataFrame(
+                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
         import pandas as pd

@@ -1058,7 +1058,7 @@
       .intConf
       .createWithDefault(100)

-  val ARROW_EXECUTION_ENABLE =
+  val ARROW_EXECUTION_ENABLED =
     buildConf("spark.sql.execution.arrow.enabled")
       .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
         "for use with pyspark.sql.DataFrame.toPandas, and " +

@@ -1068,6 +1068,13 @@
       .booleanConf
       .createWithDefault(false)

+  val ARROW_FALLBACK_ENABLED =
+    buildConf("spark.sql.execution.arrow.fallback.enabled")
+      .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " +
+        "fallback automatically to non-optimized implementations if an error occurs.")
+      .booleanConf
+      .createWithDefault(true)
+
   val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
     buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
       .doc("When using Apache Arrow, limit the maximum number of records that can be written " +

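For reference, a hedged PySpark check of the new entry's default (assuming a `SparkSession` named `spark` on a build that includes this change):

```python
# The new conf defaults to true, so Arrow-enabled conversions fall back to the
# non-Arrow implementation on pre-computation errors unless users opt out.
print(spark.conf.get("spark.sql.execution.arrow.fallback.enabled"))    # 'true'
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")  # opt out
```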
@@ -1518,7 +1525,9 @@

   def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)

-  def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
+  def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLED)
+
+  def arrowFallbackEnable: Boolean = getConf(ARROW_FALLBACK_ENABLED)

   def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)

Review comments on this change:

> So we need to be clear that we only do this if an error occurs in schema parsing, not any error.

> Let me try to rephrase this doc a bit. The point I was trying to make in this fallback (for now) was to only do the fallback before the actual distributed computation within Spark.