Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 163 additions & 21 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def load(): DataFrame = {
val dataSource =
DataSource(
sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
load(Seq.empty: _*) // force invocation of `load(...varargs...)`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deduped.

}

/**
Expand All @@ -146,18 +140,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
if (paths.isEmpty) {
sparkSession.emptyDataFrame
} else {
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
}
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the special handling of empty paths

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my PR, will add the test cases to verify all the possible inputs after this code changes. Thanks!

}

/**
* Construct a [[DataFrame]] representing the database table accessible via JDBC URL
* url named table and connection properties.
Expand Down Expand Up @@ -276,7 +267,42 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
*
* @since 1.6.0
* @since 1.4.0
*/
def json(path: String): DataFrame = json(Seq(path): _*)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made this method depend on json(varargs) to prevent code duplication.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add some comment inline on why this exists. ditto for similar functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. agreed. will add as inline comments, not as scala docs.


/**
* Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
* <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
* type. If the values do not fit in decimal, then it infers them as doubles.</li>
* <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
* <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
* <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
* </li>
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
* <li>`allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all
* character using backslash quoting mechanism</li>
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
* <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not indent correctly:

screen shot 2016-06-17 at 2 07 18 pm

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to just use the intented - notation for lists

* malformed string into a new field configured by `columnNameOfCorruptRecord`. When
* a schema is set by user, it sets `null` for extra fields.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
* <li>`columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
*
* @since 2.0.0
*/
@scala.annotation.varargs
def json(paths: String*): DataFrame = format("json").load(paths : _*)
Expand Down Expand Up @@ -326,6 +352,60 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions))(sparkSession))
}

/**
* Loads a CSV file and returns the result as a [[DataFrame]].
*
* This function will go through the input once to determine the input schema if `inferSchema`
* is enabled. To avoid going through the entire data once, disable `inferSchema` option or
* specify the schema explicitly using [[schema]].
*
* You can set the following CSV-specific options to deal with CSV files:
* <li>`sep` (default `,`): sets the single character as a separator for each
* field and value.</li>
* <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
* type.</li>
* <li>`quote` (default `"`): sets the single character used for escaping quoted values where
* the separator can be part of the value. If you would like to turn off quotations, you need to
* set not `null` but an empty string. This behaviour is different form
* `com.databricks.spark.csv`.</li>
* <li>`escape` (default `\`): sets the single character used for escaping quotes inside
* an already quoted value.</li>
* <li>`comment` (default empty string): sets the single character used for skipping lines
* beginning with this character. By default, it is disabled.</li>
* <li>`header` (default `false`): uses the first line as names of columns.</li>
* <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
* requires one extra pass over the data.</li>
* <li>`ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces
* from values being read should be skipped.</li>
* <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing
* whitespaces from values being read should be skipped.</li>
* <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
* <li>`nanValue` (default `NaN`): sets the string representation of a non-number" value.</li>
* <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity
* value.</li>
* <li>`negativeInf` (default `-Inf`): sets the string representation of a negative infinity
* value.</li>
* <li>`dateFormat` (default `null`): sets the string that indicates a date format. Custom date
* formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type
* and timestamp type. By default, it is `null` which means trying to parse times and date by
* `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.</li>
* <li>`maxColumns` (default `20480`): defines a hard limit of how many columns
* a record can have.</li>
* <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
* for any given value being read.</li>
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
* <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
* a schema is set by user, it sets `null` for extra fields.</li>
* <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
* <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
*
* @since 2.0.0
*/
def csv(path: String): DataFrame = csv(Seq(path): _*)

/**
* Loads a CSV file and returns the result as a [[DataFrame]].
*
Expand Down Expand Up @@ -381,6 +461,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
@scala.annotation.varargs
def csv(paths: String*): DataFrame = format("csv").load(paths : _*)

/**
* Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty
* [[DataFrame]] if no paths are passed in.
*
* You can set the following Parquet-specific option(s) for reading Parquet files:
* <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
* whether we should merge schemas collected from all Parquet part-files. This will override
* `spark.sql.parquet.mergeSchema`.</li>
*
* @since 2.0.0
*/
def parquet(path: String): DataFrame = parquet(Seq(path): _*)

/**
* Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty
* [[DataFrame]] if no paths are passed in.
Expand All @@ -404,7 +497,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.5.0
* @note Currently, this method can only be used after enabling Hive support.
*/
def orc(path: String): DataFrame = format("orc").load(path)
def orc(path: String): DataFrame = orc(Seq(path): _*)

/**
* Loads an ORC file and returns the result as a [[DataFrame]].
*
* @param paths input paths
* @since 2.0.0
* @note Currently, this method can only be used after enabling Hive support.
*/
@scala.annotation.varargs
def orc(paths: String*): DataFrame = format("orc").load(paths: _*)

/**
* Returns the specified table as a [[DataFrame]].
Expand All @@ -430,12 +533,51 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* spark.read().text("/path/to/spark/README.md")
* }}}
*
* @param paths input path
* @param path input path
* @since 2.0.0
*/
def text(path: String): DataFrame = text(Seq(path): _*)

/**
* Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any.
*
* Each line in the text files is a new row in the resulting DataFrame. For example:
* {{{
* // Scala:
* spark.read.text("/path/to/spark/README.md")
*
* // Java:
* spark.read().text("/path/to/spark/README.md")
* }}}
*
* @param paths input paths
* @since 1.6.0
*/
@scala.annotation.varargs
def text(paths: String*): DataFrame = format("text").load(paths : _*)

/**
* Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
* contains a single string column named "value".
*
* If the directory structure of the text files contains partitioning information, those are
* ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
*
* Each line in the text files is a new element in the resulting Dataset. For example:
* {{{
* // Scala:
* spark.read.textFile("/path/to/spark/README.md")
*
* // Java:
* spark.read().textFile("/path/to/spark/README.md")
* }}}
*
* @param path input path
* @since 2.0.0
*/
def textFile(path: String): Dataset[String] = textFile(Seq(path): _*)

/**
* Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
* contains a single string column named "value".
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package test.org.apache.spark.sql;

import java.io.File;
import java.util.Arrays;
import java.util.HashMap;

import org.apache.spark.api.java.function.Function0;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.Utils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

public class JavaDataFrameReaderWriterSuite {
private SparkSession spark = new TestSparkSession();
private StructType schema = new StructType().add("s", "string");
private transient String input;
private transient String output;

@Before
public void setUp() {
input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "input").toString();
File f = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "output");
f.delete();
output = f.toString();
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@Test
public void testFormatAPI() {
spark
.read()
.format("org.apache.spark.sql.test")
.load()
.write()
.format("org.apache.spark.sql.test")
.save();
}

@Test
public void testOptionsAPI() {
HashMap<String, String> map = new HashMap<String, String>();
map.put("e", "1");
spark
.read()
.option("a", "1")
.option("b", 1)
.option("c", 1.0)
.option("d", true)
.options(map)
.text()
.write()
.option("a", "1")
.option("b", 1)
.option("c", 1.0)
.option("d", true)
.options(map)
.format("org.apache.spark.sql.test")
.save();
}

@Test
public void testSaveModeAPI() {
spark
.range(10)
.write()
.format("org.apache.spark.sql.test")
.mode(SaveMode.ErrorIfExists)
.save();
}

@Test
public void testLoadAPI() {
spark.read().format("org.apache.spark.sql.test").load();
spark.read().format("org.apache.spark.sql.test").load(input);
spark.read().format("org.apache.spark.sql.test").load(input, input, input);
spark.read().format("org.apache.spark.sql.test").load(new String[]{input, input});
}

@Test
public void testTextAPI() {
spark.read().text();
spark.read().text(input);
spark.read().text(input, input, input);
spark.read().text(new String[]{input, input})
.write().text(output);
}

@Test
public void testTextFileAPI() {
spark.read().textFile(); // Disabled because of SPARK-XXXXX
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPARK-XXXX?

spark.read().textFile(input);
spark.read().textFile(input, input, input);
spark.read().textFile(new String[]{input, input});
}

@Test
public void testCsvAPI() {
spark.read().schema(schema).csv();
spark.read().schema(schema).csv(input);
spark.read().schema(schema).csv(input, input, input);
spark.read().schema(schema).csv(new String[]{input, input})
.write().csv(output);
}

@Test
public void testJsonAPI() {
spark.read().schema(schema).json();
spark.read().schema(schema).json(input);
spark.read().schema(schema).json(input, input, input);
spark.read().schema(schema).json(new String[]{input, input})
.write().json(output);
}

@Test
public void testParquetAPI() {
spark.read().schema(schema).parquet();
spark.read().schema(schema).parquet(input);
spark.read().schema(schema).parquet(input, input, input);
spark.read().schema(schema).parquet(new String[] { input, input })
.write().parquet(output);
}

/**
* This only tests whether API compiles, but does not run it as orc()
* cannot be run with Hive classes.
*/
public void testOrcAPI() {
spark.read().schema(schema).orc();
spark.read().schema(schema).orc(input);
spark.read().schema(schema).orc(input, input, input);
spark.read().schema(schema).orc(new String[]{input, input})
.write().orc(output);
}
}
Loading