Skip to content

Commit 2d26b95

Browse files
maropulycplus
authored and committed
[SPARK-20431][SQL] Specify a schema by using a DDL-formatted string
## What changes were proposed in this pull request? This pr supported a DDL-formatted string in `DataFrameReader.schema`. This fix could make users easily define a schema without importing `o.a.spark.sql.types._`. ## How was this patch tested? Added tests in `DataFrameReaderWriterSuite`. Author: Takeshi Yamamuro <[email protected]> Closes apache#17719 from maropu/SPARK-20431.
1 parent df7b47b commit 2d26b95

3 files changed

Lines changed: 36 additions & 8 deletions

File tree

python/pyspark/sql/readwriter.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
def schema(self, schema):
    """Specifies the input schema.

    Some data sources (e.g. JSON) can infer the input schema automatically from data.
    By specifying the schema here, the underlying data source can skip the schema
    inference step, and thus speed up data loading.

    :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
                   (For example ``col0 INT, col1 DOUBLE``).
    :return: this :class:`DataFrameReader`, to allow call chaining.
    :raises TypeError: if ``schema`` is neither a ``StructType`` nor a string.
    """
    from pyspark.sql import SparkSession

    active = SparkSession.builder.getOrCreate()
    if isinstance(schema, StructType):
        # Serialize the StructType to JSON and let the JVM re-parse it into
        # a Java-side DataType before handing it to the underlying reader.
        parsed = active._jsparkSession.parseDataType(schema.json())
        self._jreader = self._jreader.schema(parsed)
    elif isinstance(schema, basestring):
        # NOTE(review): `basestring` is Python 2 only — presumably the module
        # defines a Py3 compatibility alias at file level; confirm upstream.
        # DDL strings are parsed on the JVM side by the String overload.
        self._jreader = self._jreader.schema(schema)
    else:
        raise TypeError("schema should be StructType or string")
    return self
108112

109113
@since(1.5)
@@ -137,7 +141,8 @@ def load(self, path=None, format=None, schema=None, **options):
137141
138142
:param path: optional string or a list of string for file-system backed data sources.
139143
:param format: optional string for format of the data source. Default to 'parquet'.
140-
:param schema: optional :class:`pyspark.sql.types.StructType` for the input schema.
144+
:param schema: optional :class:`pyspark.sql.types.StructType` for the input schema
145+
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
141146
:param options: all other string options
142147
143148
>>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
@@ -181,7 +186,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
181186
182187
:param path: string represents path to the JSON dataset, or a list of paths,
183188
or RDD of Strings storing JSON objects.
184-
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
189+
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or
190+
a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
185191
:param primitivesAsString: infers all primitive values as a string type. If None is set,
186192
it uses the default value, ``false``.
187193
:param prefersDecimal: infers all floating-point values as a decimal type. If the values
@@ -324,7 +330,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
324330
``inferSchema`` option or specify the schema explicitly using ``schema``.
325331
326332
:param path: string, or list of strings, for input path(s).
327-
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
333+
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
334+
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
328335
:param sep: sets the single character as a separator for each field and value.
329336
If None is set, it uses the default value, ``,``.
330337
:param encoding: decodes the CSV files by the given encoding type. If None is set,

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
6767
this
6868
}
6969

70+
/**
71+
* Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
72+
* infer the input schema automatically from data. By specifying the schema here, the underlying
73+
* data source can skip the schema inference step, and thus speed up data loading.
74+
*
75+
* @since 2.3.0
76+
*/
77+
def schema(schemaString: String): DataFrameReader = {
78+
this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
79+
this
80+
}
81+
7082
/**
7183
* Adds an input option for the underlying data source.
7284
*

sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
128128
import testImplicits._
129129

130130
private val userSchema = new StructType().add("s", StringType)
131+
private val userSchemaString = "s STRING"
131132
private val textSchema = new StructType().add("value", StringType)
132133
private val data = Seq("1", "2", "3")
133134
private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
@@ -678,4 +679,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
678679
assert(e.contains("User specified schema not supported with `table`"))
679680
}
680681
}
682+
683+
test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
  // Write a small text dataset, then verify that the DDL-string schema overload
  // produces the same StructType as the programmatic `userSchema` across
  // zero, one, and multiple input paths.
  spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
  testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema)
  testRead(spark.read.schema(userSchemaString).text(dir), data, userSchema)
  testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema)
  testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema)
}
681690
}

0 commit comments

Comments
 (0)