[SPARK-13233][SQL] Python Dataset (basic version) #11347
Changes from 7 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,25 +18,27 @@ | |
| import sys | ||
| import warnings | ||
| import random | ||
| from itertools import chain | ||
|
|
||
| if sys.version >= '3': | ||
| basestring = unicode = str | ||
| long = int | ||
| from functools import reduce | ||
| else: | ||
| from itertools import imap as map | ||
| from itertools import imap as map, ifilter as filter | ||
|
|
||
| from pyspark import since | ||
| from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix | ||
| from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer | ||
| from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, \ | ||
| PickleSerializer, UTF8Deserializer | ||
| from pyspark.storagelevel import StorageLevel | ||
| from pyspark.traceback_utils import SCCallSiteSync | ||
| from pyspark.sql.types import _parse_datatype_json_string | ||
| from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column | ||
| from pyspark.sql.readwriter import DataFrameWriter | ||
| from pyspark.sql.types import * | ||
|
|
||
| __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] | ||
| __all__ = ["DataFrame", "Dataset", "DataFrameNaFunctions", "DataFrameStatFunctions"] | ||
|
|
||
|
|
||
| class DataFrame(object): | ||
|
|
@@ -79,11 +81,18 @@ def __init__(self, jdf, sql_ctx): | |
| @property | ||
| @since(1.3) | ||
| def rdd(self): | ||
| """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. | ||
| """Returns the content as an :class:`pyspark.RDD` of :class:`Row` or custom object. | ||
| """ | ||
| if self._lazy_rdd is None: | ||
| jrdd = self._jdf.javaToPython() | ||
| self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) | ||
| if self._jdf.isOutputPickled(): | ||
| # If the underlying Java DataFrame's output is pickled, the query engine | ||
| # doesn't know the real schema of the data and just keeps the pickled binary | ||
| # for each custom object (no batching), so we need a non-batched serializer here. | ||
| deserializer = PickleSerializer() | ||
| else: | ||
| deserializer = BatchedSerializer(PickleSerializer()) | ||
| self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, deserializer) | ||
| return self._lazy_rdd | ||
|
|
||
| @property | ||
|
|
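To make the serializer choice above concrete, here is a toy, Spark-free sketch of the two deserialization modes: a plain pickle stream with one object per record (what the non-batched path expects) versus a stream whose pickled objects are lists of records (the batched path). The helper names and the in-memory stream are illustrative assumptions, not PySpark's actual serializer classes.

```python
import pickle
from io import BytesIO

def load_unbatched(stream):
    # One pickled object per record, as a plain (non-batched) serializer expects.
    while True:
        try:
            yield pickle.load(stream)
        except EOFError:
            return

def load_batched(stream):
    # Each pickled object is a list of records; flatten the batches.
    while True:
        try:
            for obj in pickle.load(stream):
                yield obj
        except EOFError:
            return

unbatched = BytesIO()
for record in ["Alice", "Bob"]:
    pickle.dump(record, unbatched)       # one pickle frame per record
unbatched.seek(0)
print(list(load_unbatched(unbatched)))   # ['Alice', 'Bob']

batched = BytesIO()
pickle.dump(["Alice", "Bob"], batched)   # one pickle frame holding a whole batch
batched.seek(0)
print(list(load_batched(batched)))       # ['Alice', 'Bob']
```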
@@ -232,14 +241,12 @@ def count(self): | |
| @ignore_unicode_prefix | ||
| @since(1.3) | ||
| def collect(self): | ||
| """Returns all the records as a list of :class:`Row`. | ||
| """Returns all the records as a list. | ||
|
|
||
| >>> df.collect() | ||
| [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] | ||
| """ | ||
| with SCCallSiteSync(self._sc) as css: | ||
| port = self._jdf.collectToPython() | ||
| return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) | ||
| return self.rdd.collect() | ||
|
||
|
|
||
| @ignore_unicode_prefix | ||
| @since(1.3) | ||
|
|
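A brief hedged illustration of the intent behind routing collect() through self.rdd, assuming the same df fixture as the doctests: an ordinary DataFrame still yields Rows, while a typed (map-produced) one yields its custom objects directly.

```python
df.collect()                           # [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
df.map(lambda row: row.age).collect()  # [2, 5] -- plain ints, not Rows
```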
@@ -257,53 +264,85 @@ def limit(self, num): | |
| @ignore_unicode_prefix | ||
| @since(1.3) | ||
| def take(self, num): | ||
| """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. | ||
| """Returns the first ``num`` records as a :class:`list`. | ||
|
|
||
| >>> df.take(2) | ||
| [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] | ||
| """ | ||
| with SCCallSiteSync(self._sc) as css: | ||
| port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( | ||
| self._jdf, num) | ||
| return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) | ||
| return self.limit(num).collect() | ||
|
||
|
|
||
| @ignore_unicode_prefix | ||
| @since(1.3) | ||
| def map(self, f): | ||
| """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. | ||
| @since(2.0) | ||
| def applySchema(self, schema=None): | ||
| """Returns a new :class:`DataFrame` by appling the given schema, or infer the schema | ||
| by all of the records if no schema is given. | ||
|
|
||
| This is a shorthand for ``df.rdd.map()``. | ||
| Applying a schema is only allowed on a DataFrame returned by typed operations, | ||
| e.g. map, flatMap, etc. The record type of the schema-applied DataFrame will be Row. | ||
|
|
||
| >>> df.map(lambda p: p.name).collect() | ||
| >>> ds = df.map(lambda row: row.name) | ||
| >>> ds.collect() | ||
| [u'Alice', u'Bob'] | ||
| >>> ds.schema | ||
| StructType(List(StructField(value,BinaryType,false))) | ||
| >>> ds2 = ds.applySchema(StringType()) | ||
| >>> ds2.collect() | ||
| [Row(value=u'Alice'), Row(value=u'Bob')] | ||
| >>> ds2.schema | ||
| StructType(List(StructField(value,StringType,true))) | ||
| >>> ds3 = ds.applySchema() | ||
| >>> ds3.collect() | ||
| [Row(value=u'Alice'), Row(value=u'Bob')] | ||
| >>> ds3.schema | ||
| StructType(List(StructField(value,StringType,true))) | ||
| """ | ||
| msg = "Cannot apply schema to a DataFrame which is not returned by typed operations" | ||
| raise RuntimeError(msg) | ||
|
||
|
|
||
| @ignore_unicode_prefix | ||
| @since(2.0) | ||
| def mapPartitions(self, func): | ||
| """Returns a new :class:`DataFrame` by applying the ``f`` function to each partition. | ||
|
|
||
| The schema of the returned :class:`DataFrame` is a struct type with a single binary | ||
| field; please call `applySchema` to set the correct schema before applying structured | ||
| operations, e.g. select, sort, groupBy, etc. | ||
|
|
||
| >>> def f(iterator): | ||
| ... return map(lambda i: 1, iterator) | ||
| >>> df.mapPartitions(f).collect() | ||
| [1, 1] | ||
| """ | ||
| return self.rdd.map(f) | ||
| return PipelinedDataFrame(self, func) | ||
|
|
||
| @ignore_unicode_prefix | ||
| @since(1.3) | ||
| def flatMap(self, f): | ||
| """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, | ||
| and then flattening the results. | ||
| @since(2.0) | ||
|
||
| def map(self, func): | ||
| """ Returns a new :class:`DataFrame` by applying a the ``f`` function to each record. | ||
|
|
||
| This is a shorthand for ``df.rdd.flatMap()``. | ||
| The schema of the returned :class:`DataFrame` is a struct type with a single binary | ||
| field; please call `applySchema` to set the correct schema before applying structured | ||
| operations, e.g. select, sort, groupBy, etc. | ||
|
|
||
| >>> df.flatMap(lambda p: p.name).collect() | ||
| [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] | ||
| >>> df.map(lambda p: p.name).collect() | ||
| [u'Alice', u'Bob'] | ||
| """ | ||
| return self.rdd.flatMap(f) | ||
| return self.mapPartitions(lambda iterator: map(func, iterator)) | ||
|
|
||
| @since(1.3) | ||
| def mapPartitions(self, f, preservesPartitioning=False): | ||
| """Returns a new :class:`RDD` by applying the ``f`` function to each partition. | ||
| @ignore_unicode_prefix | ||
| @since(2.0) | ||
| def flatMap(self, func): | ||
| """ Returns a new :class:`DataFrame` by first applying the ``f`` function to each record, | ||
| and then flattening the results. | ||
|
|
||
| This is a shorthand for ``df.rdd.mapPartitions()``. | ||
| The schema of the returned :class:`DataFrame` is a struct type with a single binary | ||
| field; please call `applySchema` to set the correct schema before applying structured | ||
| operations, e.g. select, sort, groupBy, etc. | ||
|
|
||
| >>> rdd = sc.parallelize([1, 2, 3, 4], 4) | ||
| >>> def f(iterator): yield 1 | ||
| >>> rdd.mapPartitions(f).sum() | ||
| 4 | ||
| >>> df.flatMap(lambda p: p.name).collect() | ||
| [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] | ||
| """ | ||
| return self.rdd.mapPartitions(f, preservesPartitioning) | ||
| return self.mapPartitions(lambda iterator: chain.from_iterable(map(func, iterator))) | ||
|
|
||
| @since(1.3) | ||
| def foreach(self, f): | ||
|
|
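Putting the typed operations above together, here is a hedged end-to-end sketch of the intended workflow. It assumes the doctest fixture df and `from pyspark.sql.types import StringType`; the exact select output is my reading of the applySchema doctests, not something verified against this patch.

```python
ds = df.map(lambda row: row.name)        # typed operation: records are unicode strings
ds.schema                                # StructType(List(StructField(value,BinaryType,false)))
ds.collect()                             # [u'Alice', u'Bob']

ds2 = ds.applySchema(StringType())       # tell the engine the real element type
ds2.schema                               # StructType(List(StructField(value,StringType,true)))
ds2.select("value").collect()            # [Row(value=u'Alice'), Row(value=u'Bob')]

ds.flatMap(lambda name: name).collect()  # [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
```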
@@ -315,7 +354,7 @@ def foreach(self, f): | |
| ... print(person.name) | ||
| >>> df.foreach(f) | ||
| """ | ||
| return self.rdd.foreach(f) | ||
| self.rdd.foreach(f) | ||
|
|
||
| @since(1.3) | ||
| def foreachPartition(self, f): | ||
|
|
@@ -328,7 +367,7 @@ def foreachPartition(self, f): | |
| ... print(person.name) | ||
| >>> df.foreachPartition(f) | ||
| """ | ||
| return self.rdd.foreachPartition(f) | ||
| self.rdd.foreachPartition(f) | ||
|
|
||
| @since(1.3) | ||
| def cache(self): | ||
|
|
@@ -745,7 +784,7 @@ def head(self, n=None): | |
|
|
||
| :param n: int, default 1. Number of rows to return. | ||
| :return: If n is greater than 1, return a list of :class:`Row`. | ||
| If n is 1, return a single Row. | ||
| If n is None, return a single Row. | ||
|
|
||
| >>> df.head() | ||
| Row(age=2, name=u'Alice') | ||
|
|
@@ -843,13 +882,20 @@ def selectExpr(self, *expr): | |
| @ignore_unicode_prefix | ||
| @since(1.3) | ||
| def filter(self, condition): | ||
| """Filters rows using the given condition. | ||
| """Filters records using the given condition. | ||
|
|
||
| :func:`where` is an alias for :func:`filter`. | ||
|
|
||
| :param condition: a :class:`Column` of :class:`types.BooleanType` | ||
| or a string of SQL expression. | ||
|
|
||
| .. versionchanged:: 2.0 | ||
| Also allows the condition parameter to be a function that takes a record as input | ||
| and returns a boolean. | ||
| The schema of the returned :class:`DataFrame` is a struct type with a single binary | ||
| field; please call `applySchema` to set the correct schema before applying structured | ||
| operations, e.g. select, sort, groupBy, etc. | ||
|
|
||
| >>> df.filter(df.age > 3).collect() | ||
| [Row(age=5, name=u'Bob')] | ||
| >>> df.where(df.age == 2).collect() | ||
|
|
@@ -859,14 +905,20 @@ def filter(self, condition): | |
| [Row(age=5, name=u'Bob')] | ||
| >>> df.where("age = 2").collect() | ||
| [Row(age=2, name=u'Alice')] | ||
|
|
||
| >>> df.filter(lambda row: row.age > 3).collect() | ||
| [Row(age=5, name=u'Bob')] | ||
| >>> df.map(lambda row: row.age).filter(lambda age: age > 3).collect() | ||
|
Contributor: What's the type of [...]? This looks confusing to me.
Contributor (Author): after map, the schema is [...]
Contributor: This is the confusing part: the schema doesn't match the record object type. |
||
| [5] | ||
| """ | ||
| if isinstance(condition, basestring): | ||
| jdf = self._jdf.filter(condition) | ||
| return DataFrame(self._jdf.filter(condition), self.sql_ctx) | ||
|
Contributor: This DataFrame could have a schema or not; should we only allow this on a typed DataFrame?
Contributor (Author): A DataFrame always has a schema (there is a default). The difference is that a DataFrame with the default schema has custom objects as records, while other DataFrames have rows as records. |
||
| elif isinstance(condition, Column): | ||
| jdf = self._jdf.filter(condition._jc) | ||
| return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) | ||
| elif hasattr(condition, '__call__'): | ||
| return self.mapPartitions(lambda iterator: filter(condition, iterator)) | ||
| else: | ||
| raise TypeError("condition should be string or Column") | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| where = filter | ||
|
|
||
|
|
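The hedged sketch below collects the three filter forms supported after this change, along with the schema point raised in the review thread: after a typed map the DataFrame's schema is a single binary field, while the Python records are the mapped objects themselves. It assumes the doctest fixture df.

```python
df.filter("age > 3").collect()                # SQL expression string
df.filter(df.age > 3).collect()               # Column predicate
df.filter(lambda row: row.age > 3).collect()  # Python callable (new): [Row(age=5, name=u'Bob')]

ages = df.map(lambda row: row.age)            # records are ints...
ages.schema                                   # ...but the schema is a single binary field
ages.filter(lambda age: age > 3).collect()    # [5]
```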
@@ -1404,6 +1456,88 @@ def toPandas(self): | |
| drop_duplicates = dropDuplicates | ||
|
|
||
|
|
||
| Dataset = DataFrame | ||
|
|
||
|
|
||
| class PipelinedDataFrame(DataFrame): | ||
|
|
||
| """ | ||
| Pipelined typed operations on :class:`DataFrame`: | ||
|
|
||
| >>> df.map(lambda row: 2 * row.age).cache().map(lambda i: 2 * i).collect() | ||
| [8, 20] | ||
| >>> df.map(lambda row: 2 * row.age).map(lambda i: 2 * i).collect() | ||
| [8, 20] | ||
| """ | ||
|
|
||
| def __init__(self, prev, func): | ||
| self._jdf_val = None | ||
| self._schema = None | ||
| self.is_cached = False | ||
| self.sql_ctx = prev.sql_ctx | ||
| self._sc = self.sql_ctx and self.sql_ctx._sc | ||
| self._lazy_rdd = None | ||
|
|
||
| if not isinstance(prev, PipelinedDataFrame) or prev.is_cached: | ||
| # This is the beginning of this pipeline. | ||
| self._func = func | ||
| self._prev_jdf = prev._jdf | ||
| else: | ||
| self._func = _pipeline_func(prev._func, func) | ||
| # maintain the pipeline. | ||
| self._prev_jdf = prev._prev_jdf | ||
|
|
||
| def applySchema(self, schema=None): | ||
| if schema is None: | ||
| from pyspark.sql.types import _infer_type, _merge_type | ||
| # If no schema is specified, infer it from the whole data set. | ||
| jrdd = self._prev_jdf.javaToPython() | ||
| rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer())) | ||
| schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type) | ||
|
|
||
| if isinstance(schema, StructType): | ||
| to_rows = lambda iterator: map(schema.toInternal, iterator) | ||
| else: | ||
| data_type = schema | ||
| schema = StructType().add("value", data_type) | ||
| to_row = lambda obj: (data_type.toInternal(obj), ) | ||
| to_rows = lambda iterator: map(to_row, iterator) | ||
|
|
||
| wrapped_func = self._wrap_func(_pipeline_func(self._func, to_rows), False) | ||
| jdf = self._prev_jdf.pythonMapPartitions(wrapped_func, schema.json()) | ||
| return DataFrame(jdf, self.sql_ctx) | ||
|
|
||
| @property | ||
| def _jdf(self): | ||
| if self._jdf_val is None: | ||
| wrapped_func = self._wrap_func(self._func, True) | ||
| self._jdf_val = self._prev_jdf.pythonMapPartitions(wrapped_func) | ||
| return self._jdf_val | ||
|
|
||
| def _wrap_func(self, func, output_binary): | ||
| if self._prev_jdf.isOutputPickled(): | ||
| deserializer = PickleSerializer() | ||
| else: | ||
| deserializer = AutoBatchedSerializer(PickleSerializer()) | ||
|
|
||
| if output_binary: | ||
| serializer = PickleSerializer() | ||
| else: | ||
| serializer = AutoBatchedSerializer(PickleSerializer()) | ||
|
|
||
| from pyspark.rdd import _wrap_function | ||
| return _wrap_function(self._sc, lambda _, iterator: func(iterator), | ||
| deserializer, serializer) | ||
|
|
||
|
|
||
| def _pipeline_func(prev_func, next_func): | ||
| """ | ||
| Pipeline 2 functions into one, where each of these 2 functions takes an iterator and | ||
| returns an iterator. | ||
| """ | ||
| return lambda iterator: next_func(prev_func(iterator)) | ||
|
|
||
|
|
||
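A minimal, standalone sketch of what `_pipeline_func` (and hence PipelinedDataFrame) does: two iterator-to-iterator functions are fused into one closure, so chained map/filter/flatMap calls execute as a single Python stage. The toy functions below are mine; only the composition rule comes from the code above.

```python
def _pipeline_func(prev_func, next_func):
    # Same composition rule as above: feed prev_func's output iterator into next_func.
    return lambda iterator: next_func(prev_func(iterator))

doubled = lambda it: (2 * x for x in it)        # like df.map(lambda x: 2 * x)
only_big = lambda it: (x for x in it if x > 4)  # like .filter(lambda x: x > 4)

fused = _pipeline_func(doubled, only_big)
print(list(fused(iter([1, 2, 3, 4]))))          # [6, 8]
```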
| def _to_scala_map(sc, jm): | ||
| """ | ||
| Convert a dict into a JVM Map. | ||
|
|
||
The overhead of PickleSerializer is pretty high; it will serialize the class for each row. Could you do a benchmark to see what the difference is between non-batched and batched (both size and CPU time)?
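As a starting point for the benchmark asked for above, here is a rough, Spark-free sketch comparing per-record pickling (the non-batched path) with pickling records in batches of 100. The data, batch size, and metrics are illustrative assumptions rather than PySpark's exact code path.

```python
import pickle
import time

records = [(i, "name-%d" % i) for i in range(100000)]

# Non-batched: one pickle payload per record.
start = time.time()
unbatched = [pickle.dumps(r) for r in records]
print("non-batched: %.3fs, %d bytes"
      % (time.time() - start, sum(len(b) for b in unbatched)))

# Batched: one pickle payload per 100-record slice.
start = time.time()
batched = [pickle.dumps(records[i:i + 100]) for i in range(0, len(records), 100)]
print("batched:     %.3fs, %d bytes"
      % (time.time() - start, sum(len(b) for b in batched)))
```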