
Commit 1b7fa3d

Author: Andrew Or
Merge branch 'master' of github.com:apache/spark into spilling-tests
2 parents: 7226933 + 9a430a0

45 files changed: 1267 additions & 330 deletions

Note: large commits hide some content by default, so only a subset of the 45 changed files appears below.

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
@@ -65,6 +65,7 @@ exportMethods("arrange",
               "repartition",
               "sample",
               "sample_frac",
+              "sampleBy",
               "saveAsParquetFile",
               "saveAsTable",
               "saveDF",
@@ -254,4 +255,4 @@ export("structField",
        "structType.structField",
        "print.structType")

-export("as.data.frame")
+export("as.data.frame")

R/pkg/R/DataFrame.R

Lines changed: 15 additions & 12 deletions
@@ -1414,9 +1414,10 @@ setMethod("where",
 #' @param x A Spark DataFrame
 #' @param y A Spark DataFrame
 #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a
-#'                 Column expression. If joinExpr is omitted, join() wil perform a Cartesian join
+#'                 Column expression. If joinExpr is omitted, join() will perform a Cartesian join
 #' @param joinType The type of join to perform. The following join types are available:
-#'                 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner".
+#'                 'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left',
+#'                 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner".
 #' @return A DataFrame containing the result of the join operation.
 #' @rdname join
 #' @name join
@@ -1441,11 +1442,15 @@ setMethod("join",
             if (is.null(joinType)) {
               sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc)
             } else {
-              if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) {
+              if (joinType %in% c("inner", "outer", "full", "fullouter",
+                                  "leftouter", "left_outer", "left",
+                                  "rightouter", "right_outer", "right", "leftsemi")) {
+                joinType <- gsub("_", "", joinType)
                 sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType)
               } else {
                 stop("joinType must be one of the following types: ",
-                     "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'")
+                     "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left',
+                      'rightouter', 'right_outer', 'right', 'leftsemi'")
               }
             }
           }
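As a usage sketch (not part of the commit), the widened joinType argument accepts both the underscore and the Scala-style spellings; df and df2 below are assumed to be DataFrames sharing a "name" column:

    # gsub("_", "", joinType) normalizes "left_outer" to "leftouter" before the
    # JVM call, so both spellings reach DataFrame.join with the same type string.
    joined <- join(df, df2, df$name == df2$name)                 # default: "inner"
    louter <- join(df, df2, df$name == df2$name, "left_outer")   # same as "leftouter"
    semi   <- join(df, df2, df$name == df2$name, "leftsemi")     # keeps only df's columns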
@@ -1826,17 +1831,15 @@ setMethod("fillna",
             if (length(colNames) == 0 || !all(colNames != "")) {
               stop("value should be an a named list with each name being a column name.")
             }
-
-            # Convert to the named list to an environment to be passed to JVM
-            valueMap <- new.env()
-            for (col in colNames) {
-              # Check each item in the named list is of valid type
-              v <- value[[col]]
+            # Check each item in the named list is of valid type
+            lapply(value, function(v) {
               if (!(class(v) %in% c("integer", "numeric", "character"))) {
                 stop("Each item in value should be an integer, numeric or charactor.")
               }
-              valueMap[[col]] <- v
-            }
+            })
+
+            # Convert to the named list to an environment to be passed to JVM
+            valueMap <- convertNamedListToEnv(value)

             # When value is a named list, caller is expected not to pass in cols
             if (!is.null(cols)) {
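A hypothetical fillna() call exercising the validation above (df is assumed to have a numeric "age" column and a character "name" column):

    # Each element of the named list must be integer, numeric, or character,
    # and each name must refer to a column of df.
    filled <- fillna(df, list(age = 0, name = "unknown"))
    # fillna(df, list(age = TRUE))  # stops: logical values fail the class check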

R/pkg/R/generics.R

Lines changed: 5 additions & 1 deletion
@@ -509,6 +509,10 @@ setGeneric("sample",
 setGeneric("sample_frac",
            function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })

+#' @rdname statfunctions
+#' @export
+setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") })
+
 #' @rdname saveAsParquetFile
 #' @export
 setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame")

 #' @rdname attach
 #' @export
-setGeneric("attach")
+setGeneric("attach")

R/pkg/R/sparkR.R

Lines changed: 3 additions & 9 deletions
@@ -163,19 +163,13 @@ sparkR.init <- function(
     sparkHome <- suppressWarnings(normalizePath(sparkHome))
   }

-  sparkEnvirMap <- new.env()
-  for (varname in names(sparkEnvir)) {
-    sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
-  }
+  sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)

-  sparkExecutorEnvMap <- new.env()
-  if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
+  sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
+  if (is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
     sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
       paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
   }
-  for (varname in names(sparkExecutorEnv)) {
-    sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
-  }

   nonEmptyJars <- Filter(function(x) { x != "" }, jars)
   localJarPaths <- lapply(nonEmptyJars,
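For context, a sketch of the call site this simplifies (the parameter names come from sparkR.init above; the argument values are hypothetical):

    # sparkEnvir and sparkExecutorEnv are plain named lists; both are now turned
    # into JVM-bound environments by convertNamedListToEnv.
    sc <- sparkR.init(master = "local[2]",
                      sparkEnvir = list(spark.executor.memory = "1g"),
                      sparkExecutorEnv = list(SPARK_LOCAL_DIRS = "/tmp/spark"))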

R/pkg/R/stats.R

Lines changed: 32 additions & 0 deletions
@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
             sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support)
             collect(dataFrame(sct))
           })
+
+#' sampleBy
+#'
+#' Returns a stratified sample without replacement based on the fraction given on each stratum.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col column that defines strata
+#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is
+#'                  not specified, we treat its fraction as zero.
+#' @param seed random seed
+#' @return A new DataFrame that represents the stratified sample
+#'
+#' @rdname statfunctions
+#' @name sampleBy
+#' @export
+#' @examples
+#'\dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' sample <- sampleBy(df, "key", fractions, 36)
+#' }
+setMethod("sampleBy",
+          signature(x = "DataFrame", col = "character",
+                    fractions = "list", seed = "numeric"),
+          function(x, col, fractions, seed) {
+            fractionsEnv <- convertNamedListToEnv(fractions)
+
+            statFunctions <- callJMethod(x@sdf, "stat")
+            # Seed is expected to be Long on Scala side, here convert it to an integer
+            # due to SerDe limitation now.
+            sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed))
+            dataFrame(sdf)
+          })
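A short sketch of the new API, mirroring the @examples block above and the test added below (assumes an existing sqlContext):

    # Keep ~10% of rows with key "0" and ~20% with key "1"; strata missing from
    # the named list (key "2" here) are sampled at fraction zero.
    df <- createDataFrame(sqlContext, lapply(0:99, function(i) as.character(i %% 3)), "key")
    fractions <- list("0" = 0.1, "1" = 0.2)
    stratified <- sampleBy(df, "key", fractions, 36)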

R/pkg/R/utils.R

Lines changed: 18 additions & 0 deletions
@@ -605,3 +605,21 @@ structToList <- function(struct) {
   class(struct) <- "list"
   struct
 }
+
+# Convert a named list to an environment to be passed to JVM
+convertNamedListToEnv <- function(namedList) {
+  # Make sure each item in the list has a name
+  names <- names(namedList)
+  stopifnot(
+    if (is.null(names)) {
+      length(namedList) == 0
+    } else {
+      !any(is.na(names))
+    })
+
+  env <- new.env()
+  for (name in names) {
+    env[[name]] <- namedList[[name]]
+  }
+  env
+}
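A quick illustration of the helper's contract (plain R, no Spark session required; not part of the commit):

    env <- convertNamedListToEnv(list(a = 1L, b = "two"))
    ls(env)                            # "a" "b"
    env[["b"]]                         # "two"
    convertNamedListToEnv(list())      # allowed: NULL names, length 0
    # convertNamedListToEnv(list(1, 2))  # fails stopifnot(): the items are unnamed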

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 35 additions & 2 deletions
@@ -1071,7 +1071,7 @@ test_that("join() and merge() on a DataFrame", {
   expect_equal(names(joined2), c("age", "name", "name", "test"))
   expect_equal(count(joined2), 3)

-  joined3 <- join(df, df2, df$name == df2$name, "right_outer")
+  joined3 <- join(df, df2, df$name == df2$name, "rightouter")
   expect_equal(names(joined3), c("age", "name", "name", "test"))
   expect_equal(count(joined3), 4)
   expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2]))
@@ -1082,11 +1082,34 @@ test_that("join() and merge() on a DataFrame", {
   expect_equal(count(joined4), 4)
   expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24)

+  joined5 <- join(df, df2, df$name == df2$name, "leftouter")
+  expect_equal(names(joined5), c("age", "name", "name", "test"))
+  expect_equal(count(joined5), 3)
+  expect_true(is.na(collect(orderBy(joined5, joined5$age))$age[1]))
+
+  joined6 <- join(df, df2, df$name == df2$name, "inner")
+  expect_equal(names(joined6), c("age", "name", "name", "test"))
+  expect_equal(count(joined6), 3)
+
+  joined7 <- join(df, df2, df$name == df2$name, "leftsemi")
+  expect_equal(names(joined7), c("age", "name"))
+  expect_equal(count(joined7), 3)
+
+  joined8 <- join(df, df2, df$name == df2$name, "left_outer")
+  expect_equal(names(joined8), c("age", "name", "name", "test"))
+  expect_equal(count(joined8), 3)
+  expect_true(is.na(collect(orderBy(joined8, joined8$age))$age[1]))
+
+  joined9 <- join(df, df2, df$name == df2$name, "right_outer")
+  expect_equal(names(joined9), c("age", "name", "name", "test"))
+  expect_equal(count(joined9), 4)
+  expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2]))
+
   merged <- select(merge(df, df2, df$name == df2$name, "outer"),
                    alias(df$age + 5, "newAge"), df$name, df2$test)
   expect_equal(names(merged), c("newAge", "name", "test"))
   expect_equal(count(merged), 4)
-  expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24)
+  expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24)
 })

 test_that("toJSON() returns an RDD of the correct values", {
@@ -1393,6 +1416,16 @@ test_that("freqItems() on a DataFrame", {
   expect_identical(result[[2]], list(list(-1, -99)))
 })

+test_that("sampleBy() on a DataFrame", {
+  l <- lapply(c(0:99), function(i) { as.character(i %% 3) })
+  df <- createDataFrame(sqlContext, l, "key")
+  fractions <- list("0" = 0.1, "1" = 0.2)
+  sample <- sampleBy(df, "key", fractions, 0)
+  result <- collect(orderBy(count(groupBy(sample, "key")), "key"))
+  expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
+  expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
+})
+
 test_that("SQL error message is returned from JVM", {
   retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
   expect_equal(grepl("Table Not Found: blah", retError), TRUE)

build/mvn

Lines changed: 5 additions & 5 deletions
@@ -104,8 +104,8 @@ install_scala() {
     "scala-${scala_version}.tgz" \
     "scala-${scala_version}/bin/scala"

-  SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar"
-  SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar"
+  SCALA_COMPILER="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-compiler.jar"
+  SCALA_LIBRARY="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-library.jar"
 }

 # Setup healthy defaults for the Zinc port if none were provided from
@@ -135,10 +135,10 @@ cd "${_CALLING_DIR}"

 # Now that zinc is ensured to be installed, check its status and, if its
 # not running or just installed, start it
-if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status -port ${ZINC_PORT}`" ]; then
+if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then
   export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"}
-  ${ZINC_BIN} -shutdown -port ${ZINC_PORT}
-  ${ZINC_BIN} -start -port ${ZINC_PORT} \
+  "${ZINC_BIN}" -shutdown -port ${ZINC_PORT}
+  "${ZINC_BIN}" -start -port ${ZINC_PORT} \
     -scala-compiler "${SCALA_COMPILER}" \
     -scala-library "${SCALA_LIBRARY}" &>/dev/null
 fi

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 0 additions & 9 deletions
@@ -93,15 +93,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
     defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
   }

-  /**
-   * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`
-   * asynchronously.
-   */
-  def asyncSetupEndpointRef(
-      systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = {
-    asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
-  }
-
   /**
    * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
    * This is a blocking action.

core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala

Lines changed: 5 additions & 2 deletions
@@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc._
 import org.apache.spark.util.ThreadUtils

+/**
+ * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ */
 private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

   private class EndpointData(
@@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

   // Track the receivers whose inboxes may contain messages.
-  private val receivers = new LinkedBlockingQueue[EndpointData]()
+  private val receivers = new LinkedBlockingQueue[EndpointData]

   /**
    * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
@@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   private var stopped = false

   def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
-    val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+    val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
     synchronized {
       if (stopped) {
