5 changes: 5 additions & 0 deletions docs/sql-programming-guide.md
@@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da
`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set
the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default.

In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' will fallback automatically
to non-optimized implementations if an error occurs. This can be controlled by
Contributor:

So we need to be clear that we only do this if an error occurs in schema parsing, not any error.

Member (Author):

Let me try to rephrase this doc a bit. The point I was trying to make with this fallback (for now) was to only do the fallback before the actual distributed computation within Spark.

'spark.sql.execution.arrow.fallback.enabled'.
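
As an illustrative sketch (not part of this diff) of how the two settings are used together from PySpark, assuming an active `SparkSession` with pandas and PyArrow installed:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Opt in to Arrow-based conversion for toPandas() / createDataFrame(pandas_df).
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# Keep the automatic fallback (the default). Set this to "false" to fail fast
# instead of silently using the non-Arrow conversion when Arrow cannot be used.
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")

df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
pdf = df.toPandas()  # Arrow path if possible; otherwise warns and falls back
```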

<div class="codetabs">
<div data-lang="python" markdown="1">
{% include_example dataframe_with_arrow python/sql/arrow.py %}
@@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
## Upgrading From Spark SQL 2.3 to 2.4

- Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively.
- In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization was unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched by `spark.sql.execution.arrow.fallback.enabled`.
Member:

Not only in the migration section; I think we should also document this config in a section like PySpark Usage Guide for Pandas with Apache Arrow.

Member (Author):

Yup, added.

Member (@felixcheung), Mar 7, 2018:

`which can be switched by` -> `which can be switched off by`, or `which can be switched off with`, or `which can be turned off with`


## Upgrading From Spark SQL 2.2 to 2.3

118 changes: 76 additions & 42 deletions python/pyspark/sql/dataframe.py
@@ -1986,55 +1986,89 @@ def toPandas(self):
timezone = None

if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
should_fallback = False
Contributor:

This variable name is a little confusing to me while I'm tracing the code. How about "use_arrow", and swapping the meaning? Because right now, if a user doesn't have Arrow enabled, we skip the Arrow conversion because of the value of should_fallback, which seems... odd.

try:
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps, to_arrow_schema
from pyspark.sql.types import to_arrow_schema
from pyspark.sql.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
import pyarrow
to_arrow_schema(self.schema)
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except Exception as e:
msg = (
"Note: toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.enabled' is set to true. Please set it to false "
"to disable this.")
raise RuntimeError("%s\n%s" % (_exception_message(e), msg))
else:
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)

dtype = {}
if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \
.lower() == "true":
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"Attempts non-optimization as "
"'spark.sql.execution.arrow.fallback.enabled' is set to "
"true." % _exception_message(e))
warnings.warn(msg)
should_fallback = True
else:
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
Member:

toPandas attempted Arrow optimization because... repeats three times here, maybe we can dedup it.

Member (Author):

Hm ... I tried making a "toPandas attempted Arrow optimization because ... %s" prefix and reusing it, but it seems a little bit overkill.
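
For what it's worth, a minimal sketch of the dedup being discussed (the helper name is hypothetical, not from the patch): build the repeated prefix once and vary only the tail of the message.

```python
def _arrow_topandas_error(reason, tail):
    # Shared prefix used by all three toPandas Arrow warning/error messages.
    prefix = (
        "toPandas attempted Arrow optimization because "
        "'spark.sql.execution.arrow.enabled' is set to true; however, "
        "failed by the reason below:\n  %s\n" % reason)
    return prefix + tail

# Fallback-enabled case becomes a warning; fallback-disabled case becomes an error.
warn_msg = _arrow_topandas_error(
    "<pyarrow error>",
    "Attempts non-optimization as "
    "'spark.sql.execution.arrow.fallback.enabled' is set to true.")
error_msg = _arrow_topandas_error(
    "<pyarrow error>",
    "For fallback to non-optimization automatically, please set true to "
    "'spark.sql.execution.arrow.fallback.enabled'.")
```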

"For fallback to non-optimization automatically, please set true to "
"'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
raise RuntimeError(msg)

if not should_fallback:
Contributor:

So if I'm tracing the logic correctly: if Arrow optimizations are enabled, there is an exception parsing the schema, and we don't have fallback enabled, we go down this code path; or if we don't have Arrow enabled, we also go down this code path? It might make sense to add a comment here about when this path is intended to be taken.

Member Author (@HyukjinKwon), Feb 28, 2018:

Correct, but there's one more - we fall back if PyArrow is not installed (or the version is different). Will add some comments to make this easier to read.
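
To make the paths being traced here explicit, a self-contained, hedged summary of the intended control flow (function and argument names are made up for illustration; they are not in the patch):

```python
def topandas_path(arrow_enabled, fallback_enabled, precheck_ok):
    """Which path toPandas takes. `precheck_ok` means PyArrow is installed,
    recent enough, and the schema is convertible to an Arrow schema."""
    if not arrow_enabled:
        return "plain conversion"            # Arrow never attempted
    if not precheck_ok:
        # Failure before any Spark job runs: fall back or fail fast.
        return "plain conversion (with warning)" if fallback_enabled else "RuntimeError"
    # Errors during the actual Arrow collection always raise, because
    # Spark jobs may already have been executed by then.
    return "Arrow conversion"

print(topandas_path(False, True, True))    # plain conversion
print(topandas_path(True, True, False))    # plain conversion (with warning)
print(topandas_path(True, False, False))   # RuntimeError
print(topandas_path(True, True, True))     # Arrow conversion
```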

try:
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
import pyarrow

tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except Exception as e:
# We might have to allow fallback here as well but multiple Spark jobs can
# be executed. So, simply fail in this case for now.
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.enabled' is set to true; however, "
"failed unexpectedly:\n %s\n"
"Note that 'spark.sql.execution.arrow.fallback.enabled' does "
Contributor:

+1 good job having this explanation in the exception

"not have an effect in such failure in the middle of "
"computation." % _exception_message(e))
raise RuntimeError(msg)

# Below is toPandas without Arrow optimization.
Member (Author):

Likewise, the change from here on is due to the removed `else:` block.

pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)

dtype = {}
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
# SPARK-21766: if an integer field is nullable and has null values, it can be
# inferred by pandas as float column. Once we convert the column with NaN back
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
# float type, not the corrected type from the schema in this case.
if pandas_type is not None and \
not(isinstance(field.dataType, IntegralType) and field.nullable and
pdf[field.name].isnull().any()):
dtype[field.name] = pandas_type

for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)

if timezone is None:
return pdf
else:
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
# SPARK-21766: if an integer field is nullable and has null values, it can be
# inferred by pandas as float column. Once we convert the column with NaN back
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
# float type, not the corrected type from the schema in this case.
if pandas_type is not None and \
not(isinstance(field.dataType, IntegralType) and field.nullable and
pdf[field.name].isnull().any()):
dtype[field.name] = pandas_type

for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)

if timezone is None:
return pdf
else:
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf

def _collectAsArrow(self):
"""
22 changes: 20 additions & 2 deletions python/pyspark/sql/session.py
@@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
# Fallback to create DataFrame without arrow if raise some exception
from pyspark.util import _exception_message

if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \
.lower() == "true":
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"Attempts non-optimization as "
"'spark.sql.execution.arrow.fallback.enabled' is set to "
"true." % _exception_message(e))
warnings.warn(msg)
else:
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"For fallback to non-optimization automatically, please set true to "
"'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
raise RuntimeError(msg)
data = self._convert_from_pandas(data, schema, timezone)

if isinstance(schema, StructType):
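As a rough usage sketch grounded in the new tests below (assuming an active session, pandas installed, and a type the Arrow path cannot handle), this is roughly what a user observes when the `createDataFrame` fallback kicks in:

```python
import warnings
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")

# MapType is not supported by the Arrow path here, so the Arrow attempt fails;
# with fallback enabled we get a UserWarning and the plain conversion instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    df = spark.createDataFrame(pd.DataFrame([[{u'a': 1}]]), "a: map<string,int>")

print([str(w.message) for w in caught])  # includes "Attempts non-optimization ..."
print(df.collect())                      # [Row(a={u'a': 1})]
```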
89 changes: 73 additions & 16 deletions python/pyspark/sql/tests.py
@@ -32,7 +32,9 @@
import datetime
import array
import ctypes
import warnings
import py4j
from contextlib import contextmanager

try:
import xmlrunner
@@ -48,12 +50,13 @@
else:
import unittest

from pyspark.util import _exception_message

_pandas_requirement_message = None
try:
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
except ImportError as e:
from pyspark.util import _exception_message
# If Pandas version requirement is not satisfied, skip related tests.
_pandas_requirement_message = _exception_message(e)

Expand All @@ -62,7 +65,6 @@
from pyspark.sql.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
except ImportError as e:
from pyspark.util import _exception_message
# If Arrow version requirement is not satisfied, skip related tests.
_pyarrow_requirement_message = _exception_message(e)

@@ -195,6 +197,23 @@ def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()

@contextmanager
def sql_conf(self, key, value):
"""
A convenient context manager to test some configuration specific logic. This sets
`value` to the configuration `key` and then restores it back when it exits.
"""

orig_value = self.spark.conf.get(key, None)
self.spark.conf.set(key, value)
try:
yield
finally:
if orig_value is None:
self.spark.conf.unset(key)
else:
self.spark.conf.set(key, orig_value)

def assertPandasEqual(self, expected, result):
msg = ("DataFrames are not equal: " +
"\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3458,6 +3477,8 @@ def setUpClass(cls):

cls.spark.conf.set("spark.sql.session.timeZone", tz)
cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# Disable fallback by default to easily detect the failures.
cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
cls.schema = StructType([
StructField("1_str_t", StringType(), True),
StructField("2_int_t", IntegerType(), True),
@@ -3493,19 +3514,30 @@ def create_pandas_data_frame(self):
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)

def test_unsupported_datatype(self):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
df.toPandas()
def test_toPandas_fallback_enabled(self):
import pandas as pd

df = self.spark.createDataFrame([(None,)], schema="a binary")
with QuietTest(self.sc):
with self.assertRaisesRegexp(
Exception,
'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'):
df.toPandas()
with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", True):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
with QuietTest(self.sc):
with warnings.catch_warnings(record=True) as warns:
pdf = df.toPandas()
# Catch and check the last UserWarning.
user_warns = [
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempts non-optimization" in _exception_message(user_warns[-1]))
self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

def test_toPandas_fallback_disabled(self):
with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
Member (Author):

Hey @ueshin and @BryanCutler, do you guys like this idea?

Member:

Seems good, but how about using a dict for setting multiple configs at the same time?

Member (Author):

Yea, I was thinking that too. I took a quick look at the rest of the tests and it seems we are fine with a single pair for now. Will fix it in place in the future if you are okay with that too.

Member:

Yeah, good idea! +1 on using a dict

Member (Author):

Will fix this to use a dict here soon.
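
A rough sketch of the dict-based variant being agreed on here (not the code in this commit); it mirrors the single-key `sql_conf` helper above, intended as a test-class method, but takes a mapping and restores every original value on exit:

```python
from contextlib import contextmanager

@contextmanager
def sql_conf(self, pairs):
    """Set multiple SQL configs from a dict and restore them when leaving the block."""
    keys = list(pairs.keys())
    new_values = [pairs[k] for k in keys]
    old_values = [self.spark.conf.get(k, None) for k in keys]
    for k, v in zip(keys, new_values):
        self.spark.conf.set(k, v)
    try:
        yield
    finally:
        for k, old in zip(keys, old_values):
            if old is None:
                self.spark.conf.unset(k)
            else:
                self.spark.conf.set(k, old)

# Hypothetical usage inside a test:
#     with self.sql_conf({"spark.sql.execution.arrow.enabled": True,
#                         "spark.sql.execution.arrow.fallback.enabled": False}):
#         ...
```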

schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
df.toPandas()

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3625,7 +3657,7 @@ def test_createDataFrame_with_incorrect_schema(self):
pdf = self.create_pandas_data_frame()
wrong_schema = StructType(list(reversed(self.schema)))
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"):
self.spark.createDataFrame(pdf, schema=wrong_schema)

def test_createDataFrame_with_names(self):
@@ -3650,7 +3682,7 @@ def test_createDataFrame_column_name_encoding(self):
def test_createDataFrame_with_single_data_type(self):
import pandas as pd
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

def test_createDataFrame_does_not_modify_input(self):
@@ -3705,6 +3737,31 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df.columns)
self.assertEqual(pdf_col_names, df_arrow.columns)

def test_createDataFrame_fallback_enabled(self):
import pandas as pd

with QuietTest(self.sc):
with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", True):
with warnings.catch_warnings(record=True) as warns:
df = self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
# Catch and check the last UserWarning.
user_warns = [
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempts non-optimization" in _exception_message(user_warns[-1]))
self.assertEqual(df.collect(), [Row(a={u'a': 1})])

def test_createDataFrame_fallback_disabled(self):
import pandas as pd

with QuietTest(self.sc):
with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

# Regression test for SPARK-23314
def test_timestamp_dst(self):
import pandas as pd
SQLConf.scala
@@ -1058,7 +1058,7 @@ object SQLConf {
.intConf
.createWithDefault(100)

val ARROW_EXECUTION_ENABLE =
val ARROW_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
"for use with pyspark.sql.DataFrame.toPandas, and " +
@@ -1068,6 +1068,13 @@
.booleanConf
.createWithDefault(false)

val ARROW_FALLBACK_ENABLED =
buildConf("spark.sql.execution.arrow.fallback.enabled")
.doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " +
"fallback automatically to non-optimized implementations if an error occurs.")
.booleanConf
.createWithDefault(true)

val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
.doc("When using Apache Arrow, limit the maximum number of records that can be written " +
@@ -1518,7 +1525,9 @@

def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)

def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLED)
Member:

Actually, it seems we don't use arrowEnable either.


def arrowFallbackEnable: Boolean = getConf(ARROW_FALLBACK_ENABLED)
Member:

nit: Have we used this arrowFallbackEnable definition?


def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
