161 changes: 136 additions & 25 deletions R/pkg/R/SQLContext.R
Expand Up @@ -147,6 +147,70 @@ getDefaultSqlSource <- function() {
l[["spark.sql.sources.default"]]
}

writeToFileInArrow <- function(fileName, rdf, numPartitions) {
requireNamespace1 <- requireNamespace

# The Arrow R API is not yet released on CRAN, and CRAN requires any package used
# via requireNamespace() to be declared in DESCRIPTION so it can check that the
# package is available. Work around that check by avoiding a direct requireNamespace()
# call. As of Arrow 0.12.0, the package can be installed with install_github. See ARROW-3204.
if (requireNamespace1("arrow", quietly = TRUE)) {
record_batch <- get("record_batch", envir = asNamespace("arrow"), inherits = FALSE)
RecordBatchStreamWriter <- get(
"RecordBatchStreamWriter", envir = asNamespace("arrow"), inherits = FALSE)
FileOutputStream <- get(
"FileOutputStream", envir = asNamespace("arrow"), inherits = FALSE)

numPartitions <- if (!is.null(numPartitions)) {
numToInt(numPartitions)
} else {
1
}

rdf_slices <- if (numPartitions > 1) {
split(rdf, makeSplits(numPartitions, nrow(rdf)))
} else {
list(rdf)
}

stream_writer <- NULL
tryCatch({
for (rdf_slice in rdf_slices) {
batch <- record_batch(rdf_slice)
if (is.null(stream_writer)) {
stream <- FileOutputStream(fileName)
schema <- batch$schema
stream_writer <- RecordBatchStreamWriter(stream, schema)
}

stream_writer$write_batch(batch)
}
},
finally = {
if (!is.null(stream_writer)) {
stream_writer$close()
}
})

} else {
stop("'arrow' package should be installed.")
}
}

checkTypeRequirementForArrow <- function(dataHead, schema) {
# Currently, Arrow optimization does not support the raw type. It also does not
# support an explicit float type set by users, which leads to incorrect conversion.
# In these cases we fall back to the path without Arrow optimization.
if (any(sapply(dataHead, is.raw))) {
stop("Arrow optimization with R DataFrame does not support raw type yet.")
}
if (inherits(schema, "structType")) {
if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "FloatType"))) {
stop("Arrow optimization with R DataFrame does not support FloatType type yet.")
}
}
}

#' Create a SparkDataFrame
#'
#' Converts R data.frame or list into SparkDataFrame.
@@ -172,36 +236,76 @@ getDefaultSqlSource <- function() {
createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
numPartitions = NULL) {
sparkSession <- getSparkSession()
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
useArrow <- FALSE
firstRow <- NULL

if (is.data.frame(data)) {
# Convert data into a list of rows. Each row is a list.

# get the names of columns, they will be put into RDD
if (is.null(schema)) {
schema <- names(data)
}
# get the names of columns, they will be put into RDD
if (is.null(schema)) {
schema <- names(data)
}

# get rid of factor type
cleanCols <- function(x) {
if (is.factor(x)) {
as.character(x)
} else {
x
}
# get rid of factor type
cleanCols <- function(x) {
if (is.factor(x)) {
as.character(x)
} else {
x
}
}
data[] <- lapply(data, cleanCols)

args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
if (arrowEnabled) {
useArrow <- tryCatch({
stopifnot(length(data) > 0)
dataHead <- head(data, 1)
checkTypeRequirementForArrow(data, schema)
fileName <- tempfile(pattern = "spark-arrow", fileext = ".tmp")
tryCatch({
writeToFileInArrow(fileName, data, numPartitions)
jrddInArrow <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"readArrowStreamFromFile",
sparkSession,
fileName)
},
finally = {
# File might not be created.
suppressWarnings(file.remove(fileName))
})

firstRow <- do.call(mapply, append(args, dataHead))[[1]]
TRUE
},
error = function(e) {
warning(paste0("createDataFrame attempted Arrow optimization because ",
"'spark.sql.execution.arrow.enabled' is set to true; however, ",
"failed, attempting non-optimization. Reason: ",
e))
FALSE
})
}

if (!useArrow) {
# Convert data into a list of rows. Each row is a list.
# drop factors and wrap lists
data <- setNames(lapply(data, cleanCols), NULL)
data <- setNames(as.list(data), NULL)

# check if all columns have supported type
lapply(data, getInternalType)

# convert to rows
args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
data <- do.call(mapply, append(args, data))
if (length(data) > 0) {
firstRow <- data[[1]]
}
}
}

if (is.list(data)) {
if (useArrow) {
rdd <- jrddInArrow
} else if (is.list(data)) {
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
if (!is.null(numPartitions)) {
rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions))
@@ -215,14 +319,16 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
}

if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) {
row <- firstRDD(rdd)
if (is.null(firstRow)) {
firstRow <- firstRDD(rdd)
}
names <- if (is.null(schema)) {
names(row)
names(firstRow)
} else {
as.list(schema)
}
if (is.null(names)) {
names <- lapply(1:length(row), function(x) {
names <- lapply(1:length(firstRow), function(x) {
paste("_", as.character(x), sep = "")
})
}
@@ -237,19 +343,24 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
nn
})

types <- lapply(row, infer_type)
fields <- lapply(1:length(row), function(i) {
types <- lapply(firstRow, infer_type)
fields <- lapply(1:length(firstRow), function(i) {
structField(names[[i]], types[[i]], TRUE)
})
schema <- do.call(structType, fields)
}

stopifnot(class(schema) == "structType")

jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
srdd <- callJMethod(jrdd, "rdd")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
srdd, schema$jobj, sparkSession)
if (useArrow) {
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"toDataFrame", rdd, schema$jobj, sparkSession)
} else {
jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
srdd <- callJMethod(jrdd, "rdd")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
srdd, schema$jobj, sparkSession)
}
dataFrame(sdf)
}

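A minimal usage sketch of the new code path (an illustration, not part of the diff; it assumes an active SparkR installation with the 'arrow' R package available, and only the config key is taken from the change itself):

    library(SparkR)
    sparkR.session(sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

    # With the flag on, createDataFrame() writes the local data.frame to a temp
    # file as an Arrow record batch stream and hands that file to the JVM in one
    # step, instead of converting and shipping rows one by one.
    df <- createDataFrame(mtcars)
    head(collect(df))

    # Columns the Arrow path cannot handle yet (raw columns, or an explicit
    # FloatType in the schema) raise inside the tryCatch above, trigger a
    # warning, and fall back to the original row-by-row conversion.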
40 changes: 21 additions & 19 deletions R/pkg/R/context.R
@@ -81,6 +81,26 @@ objectFile <- function(sc, path, minPartitions = NULL) {
RDD(jrdd, "byte")
}

makeSplits <- function(numSerializedSlices, length) {
# Generate the slice ids to put each row
# For instance, for numSerializedSlices of 22, length of 50
# [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
# [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
# Notice that the slices that receive 3 rows (i.e. ids 6, 15, 22) are roughly evenly spaced.
# This reimplements the calculation from the positions method in ParallelCollectionRDD.
if (numSerializedSlices > 0) {
unlist(lapply(0: (numSerializedSlices - 1), function(x) {
# nolint start
start <- trunc((as.numeric(x) * length) / numSerializedSlices)
end <- trunc(((as.numeric(x) + 1) * length) / numSerializedSlices)
# nolint end
rep(start, end - start)
}))
} else {
1
}
}

#' Create an RDD from a homogeneous list or vector.
#'
#' This function creates an RDD from a local homogeneous list in R. The elements
@@ -143,25 +163,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
# For large objects we make sure the size of each slice is also smaller than sizeLimit
numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit)))

# Generate the slice ids to put each row
# For instance, for numSerializedSlices of 22, length of 50
# [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
# [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
# Notice the slice group with 3 slices (ie. 6, 15, 22) are roughly evenly spaced.
# We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD
splits <- if (numSerializedSlices > 0) {
unlist(lapply(0: (numSerializedSlices - 1), function(x) {
# nolint start
start <- trunc((as.numeric(x) * len) / numSerializedSlices)
end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices)
# nolint end
rep(start, end - start)
}))
} else {
1
}

slices <- split(coll, splits)
slices <- split(coll, makeSplits(numSerializedSlices, len))

# Serialize each slice: obtain a list of raws, or a list of lists (slices) of
# 2-tuples of raws
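A standalone sketch that reproduces the slice-id computation documented in makeSplits above (plain R, no Spark required; the variable names are local to the example):

    numSerializedSlices <- 22
    len <- 50
    splits <- unlist(lapply(0:(numSerializedSlices - 1), function(x) {
      start <- trunc((as.numeric(x) * len) / numSerializedSlices)
      end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices)
      rep(start, end - start)
    }))
    print(splits)   # matches the example vector in the comment above
    table(splits)   # each slice holds 2 or 3 rows, roughly evenly spaced
    # parallelize() and writeToFileInArrow() then group the data with split(coll, splits).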
57 changes: 57 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -307,6 +307,63 @@ test_that("create DataFrame from RDD", {
unsetHiveContext()
})

test_that("createDataFrame Arrow optimization", {
skip_if_not_installed("arrow")

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
tryCatch({
expected <- collect(createDataFrame(mtcars))
},
finally = {
# Resetting the conf back to default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
expect_equal(collect(createDataFrame(mtcars)), expected)
},
finally = {
# Resetting the conf back to default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("createDataFrame Arrow optimization - type specification", {
skip_if_not_installed("arrow")
rdf <- data.frame(list(list(a = 1,
b = "a",
c = TRUE,
d = 1.1,
e = 1L,
f = as.Date("1990-02-24"),
g = as.POSIXct("1990-02-24 12:34:56"))))

arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
conf <- callJMethod(sparkSession, "conf")

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
tryCatch({
expected <- collect(createDataFrame(rdf))
},
finally = {
# Resetting the conf back to default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
expect_equal(collect(createDataFrame(rdf)), expected)
},
finally = {
# Resetting the conf back to default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})
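# A hedged sketch of one more check in the same style as the tests above (not
# taken from this diff): with the flag enabled, an explicit FloatType schema is
# unsupported by the Arrow path, so createDataFrame should warn and fall back
# rather than fail.
test_that("createDataFrame Arrow optimization - unsupported type falls back", {
  skip_if_not_installed("arrow")

  conf <- callJMethod(sparkSession, "conf")
  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
  tryCatch({
    rdf <- data.frame(a = c(1.1, 2.2))
    schema <- structType(structField("a", "float"))
    expect_warning(df <- createDataFrame(rdf, schema),
                   "attempted Arrow optimization")
    expect_equal(dim(collect(df)), c(2L, 1L))
  },
  finally = {
    # Resetting the conf back to default value
    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
  })
})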

test_that("read/write csv as DataFrame", {
if (windows_with_hadoop()) {
csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
5 changes: 3 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1284,8 +1284,9 @@ object SQLConf {
val ARROW_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
"for use with pyspark.sql.DataFrame.toPandas, and " +
"pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " +
"for use with pyspark.sql.DataFrame.toPandas, " +
"pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame, " +
"and createDataFrame when its input is an R DataFrame. " +
"The following data types are unsupported: " +
"BinaryType, MapType, ArrayType of TimestampType, and nested StructType.")
.booleanConf
22 changes: 22 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.types._
@@ -237,4 +238,25 @@ private[sql] object SQLUtils extends Logging {
def createArrayType(column: Column): ArrayType = {
new ArrayType(ExprUtils.evalTypeExpr(column.expr), true)
}

/**
* R callable function to read a file in Arrow stream format and create an `RDD`
* using each serialized ArrowRecordBatch as a partition.
*/
def readArrowStreamFromFile(
sparkSession: SparkSession,
filename: String): JavaRDD[Array[Byte]] = {
ArrowConverters.readArrowStreamFromFile(sparkSession.sqlContext, filename)
}

/**
* R callable function to create a `DataFrame` from a `JavaRDD` of serialized
* ArrowRecordBatches.
*/
def toDataFrame(
arrowBatchRDD: JavaRDD[Array[Byte]],
schema: StructType,
sparkSession: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession.sqlContext)
}
}