diff --git a/.gitignore b/.gitignore
index 08f2d8f7543f..07524bc429e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ spark-tests.log
streaming-tests.log
dependency-reduced-pom.xml
.ensime
+.ensime_cache/
.ensime_lucene
checkpoint
derby.log
diff --git a/R/install-dev.bat b/R/install-dev.bat
index 008a5c668bc4..ed1c91ae3a0f 100644
--- a/R/install-dev.bat
+++ b/R/install-dev.bat
@@ -25,3 +25,9 @@ set SPARK_HOME=%~dp0..
MKDIR %SPARK_HOME%\R\lib
R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
+
+rem Zip the SparkR package so that it can be distributed to worker nodes on YARN
+pushd %SPARK_HOME%\R\lib
+%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR
+popd
+
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 59d98c9c7a64..4972bb921707 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo
# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
+# Zip the SparkR package so that it can be distributed to worker nodes on YARN
+cd $LIB_DIR
+jar cfM "$LIB_DIR/sparkr.zip" SparkR
+
popd > /dev/null
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 2ee7d6f94f1b..260c9edce62e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -98,6 +98,7 @@ exportMethods("%in%",
"add_months",
"alias",
"approxCountDistinct",
+ "array_contains",
"asc",
"ascii",
"asin",
@@ -215,6 +216,7 @@ exportMethods("%in%",
"sinh",
"size",
"skewness",
+ "sort_array",
"soundex",
"stddev",
"stddev_pop",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index fd105ba5bc9b..8a13e7a36766 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -254,7 +254,6 @@ setMethod("dtypes",
#' @family DataFrame functions
#' @rdname columns
#' @name columns
-#' @aliases names
#' @export
#' @examples
#'\dontrun{
@@ -272,7 +271,6 @@ setMethod("columns",
})
})
-#' @family DataFrame functions
#' @rdname columns
#' @name names
setMethod("names",
@@ -281,7 +279,6 @@ setMethod("names",
columns(x)
})
-#' @family DataFrame functions
#' @rdname columns
#' @name names<-
setMethod("names<-",
@@ -533,14 +530,8 @@ setMethod("distinct",
dataFrame(sdf)
})
-#' @title Distinct rows in a DataFrame
-#
-#' @description Returns a new DataFrame containing distinct rows in this DataFrame
-#'
-#' @family DataFrame functions
-#' @rdname unique
+#' @rdname distinct
#' @name unique
-#' @aliases distinct
setMethod("unique",
signature(x = "DataFrame"),
function(x) {
@@ -557,7 +548,7 @@ setMethod("unique",
#'
#' @family DataFrame functions
#' @rdname sample
-#' @aliases sample_frac
+#' @name sample
#' @export
#' @examples
#'\dontrun{
@@ -579,7 +570,6 @@ setMethod("sample",
dataFrame(sdf)
})
-#' @family DataFrame functions
#' @rdname sample
#' @name sample_frac
setMethod("sample_frac",
@@ -589,16 +579,15 @@ setMethod("sample_frac",
sample(x, withReplacement, fraction)
})
-#' Count
+#' nrow
#'
#' Returns the number of rows in a DataFrame
#'
#' @param x A SparkSQL DataFrame
#'
#' @family DataFrame functions
-#' @rdname count
+#' @rdname nrow
#' @name count
-#' @aliases nrow
#' @export
#' @examples
#'\dontrun{
@@ -614,14 +603,8 @@ setMethod("count",
callJMethod(x@sdf, "count")
})
-#' @title Number of rows for a DataFrame
-#' @description Returns number of rows in a DataFrames
-#'
#' @name nrow
-#'
-#' @family DataFrame functions
#' @rdname nrow
-#' @aliases count
setMethod("nrow",
signature(x = "DataFrame"),
function(x) {
@@ -870,7 +853,6 @@ setMethod("toRDD",
#' @param x a DataFrame
#' @return a GroupedData
#' @seealso GroupedData
-#' @aliases group_by
#' @family DataFrame functions
#' @rdname groupBy
#' @name groupBy
@@ -896,7 +878,6 @@ setMethod("groupBy",
groupedData(sgd)
})
-#' @family DataFrame functions
#' @rdname groupBy
#' @name group_by
setMethod("group_by",
@@ -913,7 +894,6 @@ setMethod("group_by",
#' @family DataFrame functions
#' @rdname agg
#' @name agg
-#' @aliases summarize
#' @export
setMethod("agg",
signature(x = "DataFrame"),
@@ -921,7 +901,6 @@ setMethod("agg",
agg(groupBy(x), ...)
})
-#' @family DataFrame functions
#' @rdname agg
#' @name summarize
setMethod("summarize",
@@ -1092,7 +1071,6 @@ setMethod("[", signature(x = "DataFrame", i = "Column"),
#' @family DataFrame functions
#' @rdname subset
#' @name subset
-#' @aliases [
#' @family subsetting functions
#' @examples
#' \dontrun{
@@ -1216,7 +1194,7 @@ setMethod("selectExpr",
#' @family DataFrame functions
#' @rdname withColumn
#' @name withColumn
-#' @aliases mutate transform
+#' @seealso \link{rename} \link{mutate}
#' @export
#' @examples
#'\dontrun{
@@ -1231,7 +1209,6 @@ setMethod("withColumn",
function(x, colName, col) {
select(x, x$"*", alias(col, colName))
})
-
#' Mutate
#'
#' Return a new DataFrame with the specified columns added.
@@ -1240,9 +1217,9 @@ setMethod("withColumn",
#' @param col a named argument of the form name = col
#' @return A new DataFrame with the new columns added.
#' @family DataFrame functions
-#' @rdname withColumn
+#' @rdname mutate
#' @name mutate
-#' @aliases withColumn transform
+#' @seealso \link{rename} \link{withColumn}
#' @export
#' @examples
#'\dontrun{
@@ -1273,17 +1250,15 @@ setMethod("mutate",
})
#' @export
-#' @family DataFrame functions
-#' @rdname withColumn
+#' @rdname mutate
#' @name transform
-#' @aliases withColumn mutate
setMethod("transform",
signature(`_data` = "DataFrame"),
function(`_data`, ...) {
mutate(`_data`, ...)
})
-#' WithColumnRenamed
+#' rename
#'
#' Rename an existing column in a DataFrame.
#'
@@ -1292,8 +1267,9 @@ setMethod("transform",
#' @param newCol The new column name.
#' @return A DataFrame with the column name changed.
#' @family DataFrame functions
-#' @rdname withColumnRenamed
+#' @rdname rename
#' @name withColumnRenamed
+#' @seealso \link{mutate}
#' @export
#' @examples
#'\dontrun{
@@ -1316,17 +1292,9 @@ setMethod("withColumnRenamed",
select(x, cols)
})
-#' Rename
-#'
-#' Rename an existing column in a DataFrame.
-#'
-#' @param x A DataFrame
-#' @param newCol A named pair of the form new_column_name = existing_column
-#' @return A DataFrame with the column name changed.
-#' @family DataFrame functions
-#' @rdname withColumnRenamed
+#' @param newColPair A named pair of the form new_column_name = existing_column
+#' @rdname rename
#' @name rename
-#' @aliases withColumnRenamed
#' @export
#' @examples
#'\dontrun{
@@ -1371,7 +1339,6 @@ setClassUnion("characterOrColumn", c("character", "Column"))
#' @family DataFrame functions
#' @rdname arrange
#' @name arrange
-#' @aliases orderby
#' @export
#' @examples
#'\dontrun{
@@ -1395,8 +1362,8 @@ setMethod("arrange",
dataFrame(sdf)
})
-#' @family DataFrame functions
#' @rdname arrange
+#' @name arrange
#' @export
setMethod("arrange",
signature(x = "DataFrame", col = "character"),
@@ -1427,9 +1394,9 @@ setMethod("arrange",
do.call("arrange", c(x, jcols))
})
-#' @family DataFrame functions
#' @rdname arrange
-#' @name orderby
+#' @name orderBy
+#' @export
setMethod("orderBy",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col) {
@@ -1492,6 +1459,7 @@ setMethod("where",
#' @family DataFrame functions
#' @rdname join
#' @name join
+#' @seealso \link{merge}
#' @export
#' @examples
#'\dontrun{
@@ -1528,9 +1496,7 @@ setMethod("join",
dataFrame(sdf)
})
-#'
#' @name merge
-#' @aliases join
#' @title Merges two data frames
#' @param x the first data frame to be joined
#' @param y the second data frame to be joined
@@ -1550,6 +1516,7 @@ setMethod("join",
#' outer join will be returned.
#' @family DataFrame functions
#' @rdname merge
+#' @seealso \link{join}
#' @export
#' @examples
#'\dontrun{
@@ -1671,7 +1638,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
cols
}
-#' UnionAll
+#' rbind
#'
#' Return a new DataFrame containing the union of rows in this DataFrame
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
@@ -1681,7 +1648,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
#' @param y A Spark DataFrame
#' @return A DataFrame containing the result of the union.
#' @family DataFrame functions
-#' @rdname unionAll
+#' @rdname rbind
#' @name unionAll
#' @export
#' @examples
@@ -1700,13 +1667,11 @@ setMethod("unionAll",
})
#' @title Union two or more DataFrames
-#'
#' @description Returns a new DataFrame containing rows of all parameters.
#'
-#' @family DataFrame functions
#' @rdname rbind
#' @name rbind
-#' @aliases unionAll
+#' @export
setMethod("rbind",
signature(... = "DataFrame"),
function(x, ..., deparse.level = 1) {
@@ -1795,7 +1760,6 @@ setMethod("except",
#' @family DataFrame functions
#' @rdname write.df
#' @name write.df
-#' @aliases saveDF
#' @export
#' @examples
#'\dontrun{
@@ -1828,7 +1792,6 @@ setMethod("write.df",
callJMethod(df@sdf, "save", source, jmode, options)
})
-#' @family DataFrame functions
#' @rdname write.df
#' @name saveDF
#' @export
@@ -1891,7 +1854,7 @@ setMethod("saveAsTable",
callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options)
})
-#' describe
+#' summary
#'
#' Computes statistics for numeric columns.
#' If no columns are given, this function computes statistics for all numerical columns.
@@ -1901,9 +1864,8 @@ setMethod("saveAsTable",
#' @param ... Additional expressions
#' @return A DataFrame
#' @family DataFrame functions
-#' @rdname describe
+#' @rdname summary
#' @name describe
-#' @aliases summary
#' @export
#' @examples
#'\dontrun{
@@ -1923,8 +1885,7 @@ setMethod("describe",
dataFrame(sdf)
})
-#' @family DataFrame functions
-#' @rdname describe
+#' @rdname summary
#' @name describe
setMethod("describe",
signature(x = "DataFrame"),
@@ -1934,11 +1895,6 @@ setMethod("describe",
dataFrame(sdf)
})
-#' @title Summary
-#'
-#' @description Computes statistics for numeric columns of the DataFrame
-#'
-#' @family DataFrame functions
#' @rdname summary
#' @name summary
setMethod("summary",
@@ -1966,7 +1922,6 @@ setMethod("summary",
#' @family DataFrame functions
#' @rdname nafunctions
#' @name dropna
-#' @aliases na.omit
#' @export
#' @examples
#'\dontrun{
@@ -1993,7 +1948,6 @@ setMethod("dropna",
dataFrame(sdf)
})
-#' @family DataFrame functions
#' @rdname nafunctions
#' @name na.omit
#' @export
@@ -2019,9 +1973,7 @@ setMethod("na.omit",
#' type are ignored. For example, if value is a character, and
#' subset contains a non-character column, then the non-character
#' column is simply ignored.
-#' @return A DataFrame
#'
-#' @family DataFrame functions
#' @rdname nafunctions
#' @name fillna
#' @export
@@ -2152,7 +2104,7 @@ setMethod("with",
})
#' Returns the column types of a DataFrame.
-#'
+#'
#' @name coltypes
#' @title Get column types of a DataFrame
#' @family dataframe_funcs
@@ -2198,4 +2150,4 @@ setMethod("coltypes",
rTypes[naIndices] <- types[naIndices]
rTypes
- })
\ No newline at end of file
+ })
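
For readers of the roxygen changes above: several base R generics are now documented together with their DataFrame counterparts (distinct/unique, count/nrow, describe/summary, withColumnRenamed/rename, unionAll/rbind). A minimal usage sketch, assuming an existing `sqlContext`; illustrative only, not part of the patch:

    df <- createDataFrame(sqlContext, faithful)    # faithful is a base R dataset
    nrow(df)                                       # same result as count(df)
    unique(df)                                     # same result as distinct(df)
    collect(summary(df))                           # same result as describe(df)
    df2 <- rename(df, waiting_time = df$waiting)   # documented alongside withColumnRenamed
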
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index fd013fdb304d..a62b25fde926 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -17,27 +17,33 @@
# SQLcontext.R: SQLContext-driven functions
+
+# Map top level R type to SQL type
+getInternalType <- function(x) {
+ # class of POSIXlt is c("POSIXlt" "POSIXt")
+ switch(class(x)[[1]],
+ integer = "integer",
+ character = "string",
+ logical = "boolean",
+ double = "double",
+ numeric = "double",
+ raw = "binary",
+ list = "array",
+ struct = "struct",
+ environment = "map",
+ Date = "date",
+ POSIXlt = "timestamp",
+ POSIXct = "timestamp",
+ stop(paste("Unsupported type for DataFrame:", class(x))))
+}
+
#' infer the SQL type
infer_type <- function(x) {
if (is.null(x)) {
stop("can not infer type from NULL")
}
- # class of POSIXlt is c("POSIXlt" "POSIXt")
- type <- switch(class(x)[[1]],
- integer = "integer",
- character = "string",
- logical = "boolean",
- double = "double",
- numeric = "double",
- raw = "binary",
- list = "array",
- struct = "struct",
- environment = "map",
- Date = "date",
- POSIXlt = "timestamp",
- POSIXct = "timestamp",
- stop(paste("Unsupported type for DataFrame:", class(x))))
+ type <- getInternalType(x)
if (type == "map") {
stopifnot(length(x) > 0)
@@ -90,19 +96,25 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0
if (is.null(schema)) {
schema <- names(data)
}
- n <- nrow(data)
- m <- ncol(data)
+
# get rid of factor type
- dropFactor <- function(x) {
+ cleanCols <- function(x) {
if (is.factor(x)) {
as.character(x)
} else {
x
}
}
- data <- lapply(1:n, function(i) {
- lapply(1:m, function(j) { dropFactor(data[i,j]) })
- })
+
+ # drop factors and wrap lists
+ data <- setNames(lapply(data, cleanCols), NULL)
+
+    # check if all columns have a supported type
+ lapply(data, getInternalType)
+
+ # convert to rows
+ args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
+ data <- do.call(mapply, append(args, data))
}
if (is.list(data)) {
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext)
diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R
index 2403925b267c..38f0eed95e06 100644
--- a/R/pkg/R/broadcast.R
+++ b/R/pkg/R/broadcast.R
@@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) {
#
# @param bcast The broadcast variable to get
# @rdname broadcast
-# @aliases value,Broadcast-method
setMethod("value",
signature(bcast = "Broadcast"),
function(bcast) {
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 3d0255a62f15..25a1f2210149 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -373,22 +373,6 @@ setMethod("exp",
column(jc)
})
-#' explode
-#'
-#' Creates a new row for each element in the given array or map column.
-#'
-#' @rdname explode
-#' @name explode
-#' @family collection_funcs
-#' @export
-#' @examples \dontrun{explode(df$c)}
-setMethod("explode",
- signature(x = "Column"),
- function(x) {
- jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc)
- column(jc)
- })
-
#' expm1
#'
#' Computes the exponential of the given value minus one.
@@ -980,22 +964,6 @@ setMethod("sinh",
column(jc)
})
-#' size
-#'
-#' Returns length of array or map.
-#'
-#' @rdname size
-#' @name size
-#' @family collection_funcs
-#' @export
-#' @examples \dontrun{size(df$c)}
-setMethod("size",
- signature(x = "Column"),
- function(x) {
- jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc)
- column(jc)
- })
-
#' skewness
#'
#' Aggregate function: returns the skewness of the values in a group.
@@ -2236,7 +2204,7 @@ setMethod("denseRank",
#' @export
#' @examples \dontrun{lag(df$c)}
setMethod("lag",
- signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"),
+ signature(x = "characterOrColumn"),
function(x, offset, defaultValue = NULL) {
col <- if (class(x) == "Column") {
x@jc
@@ -2365,3 +2333,80 @@ setMethod("rowNumber",
jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber")
column(jc)
})
+
+###################### Collection functions ######################
+
+#' array_contains
+#'
+#' Returns true if the array contains the value.
+#'
+#' @param x A Column
+#' @param value A value to be checked if contained in the column
+#' @rdname array_contains
+#' @name array_contains
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{array_contains(df$c, 1)}
+setMethod("array_contains",
+ signature(x = "Column", value = "ANY"),
+ function(x, value) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value)
+ column(jc)
+ })
+
+#' explode
+#'
+#' Creates a new row for each element in the given array or map column.
+#'
+#' @rdname explode
+#' @name explode
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{explode(df$c)}
+setMethod("explode",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc)
+ column(jc)
+ })
+
+#' size
+#'
+#' Returns length of array or map.
+#'
+#' @rdname size
+#' @name size
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{size(df$c)}
+setMethod("size",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc)
+ column(jc)
+ })
+
+#' sort_array
+#'
+#' Sorts the input array for the given column in ascending order,
+#' according to the natural ordering of the array elements.
+#'
+#' @param x A Column to sort
+#' @param asc A logical flag indicating the sorting order:
+#'            TRUE sorts in ascending order,
+#'            FALSE sorts in descending order.
+#' @rdname sort_array
+#' @name sort_array
+#' @family collection_funcs
+#' @export
+#' @examples
+#' \dontrun{
+#' sort_array(df$c)
+#' sort_array(df$c, FALSE)
+#' }
+setMethod("sort_array",
+ signature(x = "Column"),
+ function(x, asc = TRUE) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc)
+ column(jc)
+ })
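
A short sketch of the new collection functions, mirroring the test data added further below (assumes an existing `sqlContext`; illustrative only):

    df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
    collect(select(df, array_contains(df[[1]], 1L)))   # TRUE for the first row, FALSE for the second
    collect(select(df, sort_array(df[[1]])))           # ascending: 1 2 3 and 4 5 6
    collect(select(df, sort_array(df[[1]], FALSE)))    # descending: 3 2 1 and 6 5 4
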
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 612e639f8ad9..1b3f10ea0464 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -397,7 +397,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") })
#' @export
setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") })
-#' @rdname describe
+#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
@@ -459,11 +459,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
#' @export
setGeneric("limit", function(x, num) {standardGeneric("limit") })
-#' rdname merge
+#' @rdname merge
#' @export
setGeneric("merge")
-#' @rdname withColumn
+#' @rdname mutate
#' @export
setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") })
@@ -475,7 +475,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
#' @export
setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
-#' @rdname withColumnRenamed
+#' @rdname rename
#' @export
setGeneric("rename", function(x, ...) { standardGeneric("rename") })
@@ -539,7 +539,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
# @rdname subset
# @export
-setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") })
+setGeneric("subset", function(x, ...) { standardGeneric("subset") })
#' @rdname agg
#' @export
@@ -553,7 +553,7 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
setGeneric("toRDD", function(x) { standardGeneric("toRDD") })
-#' @rdname unionAll
+#' @rdname rbind
#' @export
setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") })
@@ -565,7 +565,7 @@ setGeneric("where", function(x, condition) { standardGeneric("where") })
#' @export
setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") })
-#' @rdname withColumnRenamed
+#' @rdname rename
#' @export
setGeneric("withColumnRenamed",
function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") })
@@ -644,6 +644,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") })
#' @export
setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") })
+#' @rdname array_contains
+#' @export
+setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") })
+
#' @rdname ascii
#' @export
setGeneric("ascii", function(x) { standardGeneric("ascii") })
@@ -786,7 +790,7 @@ setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") })
#' @rdname lag
#' @export
-setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") })
+setGeneric("lag", function(x, ...) { standardGeneric("lag") })
#' @rdname last
#' @export
@@ -961,6 +965,10 @@ setGeneric("size", function(x) { standardGeneric("size") })
#' @export
setGeneric("skewness", function(x) { standardGeneric("skewness") })
+#' @rdname sort_array
+#' @export
+setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") })
+
#' @rdname soundex
#' @export
setGeneric("soundex", function(x) { standardGeneric("soundex") })
@@ -1054,6 +1062,10 @@ setGeneric("year", function(x) { standardGeneric("year") })
#' @export
setGeneric("glm")
+#' @rdname predict
+#' @export
+setGeneric("predict", function(object, ...) { standardGeneric("predict") })
+
#' @rdname rbind
#' @export
setGeneric("rbind", signature = "...")
@@ -1072,4 +1084,4 @@ setGeneric("with")
#' @rdname coltypes
#' @export
-setGeneric("coltypes", function(x) { standardGeneric("coltypes") })
\ No newline at end of file
+setGeneric("coltypes", function(x) { standardGeneric("coltypes") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index e5f702faee65..23b49aebda05 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -68,7 +68,7 @@ setMethod("count",
dataFrame(callJMethod(x@sgd, "count"))
})
-#' Agg
+#' summarize
#'
#' Aggregates on the entire DataFrame without groups.
#' The resulting DataFrame will also contain the grouping columns.
@@ -78,12 +78,14 @@ setMethod("count",
#'
#' @param x a GroupedData
#' @return a DataFrame
-#' @rdname agg
+#' @rdname summarize
+#' @name agg
#' @family agg_funcs
#' @examples
#' \dontrun{
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
-#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
+#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
+#' df4 <- summarize(df, ageSum = max(df$age))
#' }
setMethod("agg",
signature(x = "GroupedData"),
@@ -110,8 +112,8 @@ setMethod("agg",
dataFrame(sdf)
})
-#' @rdname agg
-#' @aliases agg
+#' @rdname summarize
+#' @name summarize
setMethod("summarize",
signature(x = "GroupedData"),
function(x, ...) {
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index f23e1c7f1fce..8d3b4388ae57 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -32,6 +32,12 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
+#' @param standardize Whether to standardize features before training
+#' @param solver The solver algorithm used for optimization; this can be "l-bfgs", "normal" or
+#'               "auto". "l-bfgs" denotes limited-memory BFGS, a limited-memory quasi-Newton
+#'               optimization method. "normal" denotes using the Normal Equation as an
+#'               analytical solution to the linear regression problem. The default value is
+#'               "auto", which means that the solver algorithm is selected automatically.
#' @return a fitted MLlib model
#' @rdname glm
#' @export
@@ -79,9 +85,15 @@ setMethod("predict", signature(object = "PipelineModel"),
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
-#' @param x A fitted MLlib model
-#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See
-#' summary.glm for more information.
+#' @param object A fitted MLlib model
+#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family
+#' or a list with 'coefficients' component for binomial family. \cr
+#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals
+#' of the estimation, the 'coefficients' gives the estimated coefficients and their
+#'         estimated standard errors, t values and p-values. (These are only available when the
+#'         model is fitted by the normal solver.) \cr
+#' For binomial family: the 'coefficients' gives the estimated coefficients.
+#' See summary.glm for more information. \cr
#' @rdname summary
#' @export
#' @examples
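
A sketch of the expanded summary() contract for the gaussian family fitted with the normal solver, reusing the model from the tests (assumes a SparkR DataFrame `training`; illustrative only):

    model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
    s <- summary(model)
    s$coefficients        # estimates, standard errors, t values, p-values (normal solver only)
    s$devianceResiduals   # min/max deviance residuals (gaussian family only)
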
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index ebe2b2b8dc1d..7ff3fa628b9c 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -48,6 +48,12 @@ sparkR.stop <- function() {
}
}
+ # Remove the R package lib path from .libPaths()
+ if (exists(".libPath", envir = env)) {
+ libPath <- get(".libPath", envir = env)
+ .libPaths(.libPaths()[.libPaths() != libPath])
+ }
+
if (exists(".backendLaunched", envir = env)) {
callJStatic("SparkRHandler", "stopBackend")
}
@@ -155,14 +161,20 @@ sparkR.init <- function(
f <- file(path, open="rb")
backendPort <- readInt(f)
monitorPort <- readInt(f)
+ rLibPath <- readString(f)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
- length(monitorPort) == 0 || monitorPort == 0) {
+ length(monitorPort) == 0 || monitorPort == 0 ||
+ length(rLibPath) != 1) {
stop("JVM failed to launch")
}
assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
+ if (rLibPath != "") {
+ assign(".libPath", rLibPath, envir = .sparkREnv)
+ .libPaths(c(rLibPath, .libPaths()))
+ }
}
.sparkREnv$backendPort <- backendPort
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index db3b2c4bbd79..45c77a86c958 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -635,4 +635,4 @@ assignNewEnv <- function(data) {
assign(x = cols[i], value = data[, cols[i]], envir = env)
}
env
-}
\ No newline at end of file
+}
diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R
index 2a8a8213d084..c55fe9ba7af7 100644
--- a/R/pkg/inst/profile/general.R
+++ b/R/pkg/inst/profile/general.R
@@ -17,6 +17,7 @@
.First <- function() {
packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR")
- .libPaths(c(packageDir, .libPaths()))
+ dirs <- strsplit(packageDir, ",")[[1]]
+ .libPaths(c(dirs, .libPaths()))
Sys.setenv(NOAWT=1)
}
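
SPARKR_PACKAGE_DIR (and SPARKR_RLIBDIR on workers) now carries a comma-separated list of library paths rather than a single path. A tiny illustration of the split, using hypothetical paths:

    packageDir <- "/opt/spark/R/lib,/tmp/spark-rpkg"   # hypothetical value of SPARKR_PACKAGE_DIR
    dirs <- strsplit(packageDir, ",")[[1]]
    .libPaths(c(dirs, .libPaths()))                    # both directories are prepended to the search path
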
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index d497ad8c9daa..e0667e5e22c1 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -31,6 +31,11 @@ test_that("glm and predict", {
model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
prediction <- predict(model, test)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+
+ # Test stats::predict is working
+ x <- rnorm(15)
+ y <- x + rnorm(15)
+ expect_equal(length(predict(lm(y ~ x))), 15)
})
test_that("glm should work with long formula", {
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index af024e6183a3..3f4f319fe745 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -229,7 +229,7 @@ test_that("create DataFrame from list or data.frame", {
df <- createDataFrame(sqlContext, l, c("a", "b"))
expect_equal(columns(df), c("a", "b"))
- l <- list(list(a=1, b=2), list(a=3, b=4))
+ l <- list(list(a = 1, b = 2), list(a = 3, b = 4))
df <- createDataFrame(sqlContext, l)
expect_equal(columns(df), c("a", "b"))
@@ -242,6 +242,14 @@ test_that("create DataFrame from list or data.frame", {
expect_equal(count(df), 3)
ldf2 <- collect(df)
expect_equal(ldf$a, ldf2$a)
+
+ irisdf <- createDataFrame(sqlContext, iris)
+ iris_collected <- collect(irisdf)
+ expect_equivalent(iris_collected[,-5], iris[,-5])
+ expect_equal(iris_collected$Species, as.character(iris$Species))
+
+ mtcarsdf <- createDataFrame(sqlContext, mtcars)
+ expect_equivalent(collect(mtcarsdf), mtcars)
})
test_that("create DataFrame with different data types", {
@@ -283,6 +291,18 @@ test_that("create DataFrame with complex types", {
expect_equal(s$b, 3L)
})
+test_that("create DataFrame from a data.frame with complex types", {
+ ldf <- data.frame(row.names = 1:2)
+ ldf$a_list <- list(list(1, 2), list(3, 4))
+ ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3)))
+
+ sdf <- createDataFrame(sqlContext, ldf)
+ collected <- collect(sdf)
+
+ expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE])
+ expect_equal(ldf$an_envir, collected$an_envir)
+})
+
# For test map type and struct type in DataFrame
mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}",
"{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}",
@@ -413,6 +433,10 @@ test_that("table() returns a new DataFrame", {
expect_is(tabledf, "DataFrame")
expect_equal(count(tabledf), 3)
dropTempTable(sqlContext, "table1")
+
+ # Test base::table is working
+ #a <- letters[1:3]
+ #expect_equal(class(table(a, sample(a))), "table")
})
test_that("toRDD() returns an RRDD", {
@@ -653,6 +677,9 @@ test_that("sample on a DataFrame", {
# Also test sample_frac
sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
expect_true(count(sampled3) < 3)
+
+ # Test base::sample is working
+ #expect_equal(length(sample(1:12)), 12)
})
test_that("select operators", {
@@ -733,6 +760,9 @@ test_that("subsetting", {
df6 <- subset(df, df$age %in% c(30), c(1,2))
expect_equal(count(df6), 1)
expect_equal(columns(df6), c("name", "age"))
+
+ # Test base::subset is working
+ expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68)
})
test_that("selectExpr() on a DataFrame", {
@@ -858,6 +888,19 @@ test_that("column functions", {
df4 <- createDataFrame(sqlContext, list(list(a = "010101")))
expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15")
+
+ # Test array_contains() and sort_array()
+ df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
+ result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
+ expect_equal(result, c(TRUE, FALSE))
+
+ result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
+ expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L)))
+ result <- collect(select(df, sort_array(df[[1]])))[[1]]
+ expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L)))
+
+ # Test that stats::lag is working
+ expect_equal(length(lag(ldeaths, 12)), 72)
})
#
test_that("column binary mathfunctions", {
@@ -1056,7 +1099,7 @@ test_that("group by, agg functions", {
gd3_local <- collect(agg(gd3, var(df8$age)))
expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2])
- # make sure base:: or stats::sd, var are working
+ # Test stats::sd, stats::var are working
expect_true(abs(sd(1:2) - 0.7071068) < 1e-6)
expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6)
@@ -1108,6 +1151,9 @@ test_that("filter() on a DataFrame", {
expect_equal(count(filtered5), 1)
filtered6 <- where(df, df$age %in% c(19, 30))
expect_equal(count(filtered6), 2)
+
+ # Test stats::filter is working
+ #expect_true(is.ts(filter(1:100, rep(1, 3))))
})
test_that("join() and merge() on a DataFrame", {
@@ -1254,6 +1300,12 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", {
expect_is(unioned, "DataFrame")
expect_equal(count(intersected), 1)
expect_equal(first(intersected)$name, "Andy")
+
+ # Test base::rbind is working
+ expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16)
+
+ # Test base::intersect is working
+ expect_equal(length(intersect(1:20, 3:23)), 18)
})
test_that("withColumn() and withColumnRenamed()", {
@@ -1335,6 +1387,9 @@ test_that("describe() and summarize() on a DataFrame", {
stats2 <- summary(df)
expect_equal(collect(stats2)[4, "name"], "Andy")
expect_equal(collect(stats2)[5, "age"], "30")
+
+ # Test base::summary is working
+ expect_equal(length(summary(attenu, digits = 4)), 35)
})
test_that("dropna() and na.omit() on a DataFrame", {
@@ -1418,6 +1473,9 @@ test_that("dropna() and na.omit() on a DataFrame", {
expect_identical(expected, actual)
actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height")))
expect_identical(expected, actual)
+
+ # Test stats::na.omit is working
+ expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2)
})
test_that("fillna() on a DataFrame", {
@@ -1480,6 +1538,9 @@ test_that("cov() and corr() on a DataFrame", {
expect_true(abs(result - 1.0) < 1e-12)
result <- corr(df, "singles", "doubles", "pearson")
expect_true(abs(result - 1.0) < 1e-12)
+
+ # Test stats::cov is working
+ #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3)
})
test_that("freqItems() on a DataFrame", {
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index 3584b418a71a..f55beac6c8c0 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -18,10 +18,11 @@
# Worker daemon
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
-script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/")
+dirs <- strsplit(rLibDir, ",")[[1]]
+script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
# preload SparkR package, speedup worker
-.libPaths(c(rLibDir, .libPaths()))
+.libPaths(c(dirs, .libPaths()))
suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 0c3b0d1f4be2..3ae072beca11 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -35,10 +35,11 @@ bootTime <- currentTimeSecs()
bootElap <- elapsedSecs()
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+dirs <- strsplit(rLibDir, ",")[[1]]
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
# SparkR namespace
-.libPaths(c(rLibDir, .libPaths()))
+.libPaths(c(dirs, .libPaths()))
suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
diff --git a/core/pom.xml b/core/pom.xml
index 7e1205a076f2..37e3f168ab37 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -177,6 +177,10 @@
       <groupId>net.jpountz.lz4</groupId>
       <artifactId>lz4</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.roaringbitmap</groupId>
+      <artifactId>RoaringBitmap</artifactId>
+    </dependency>
     <dependency>
       <groupId>commons-net</groupId>
       <artifactId>commons-net</artifactId>
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java
index d00551bb0add..b9d9777a7565 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java
@@ -25,5 +25,5 @@
* when mapping RDDs of other types.
*/
public interface Function<T1, R> extends Serializable {
- public R call(T1 v1) throws Exception;
+ R call(T1 v1) throws Exception;
}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
index 2935f9986a56..4f3f222e064b 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
@@ -21,7 +21,7 @@
import java.util.Iterator;
/**
- * Base interface for a map function used in GroupedDataset's map function.
+ * Base interface for a map function used in GroupedDataset's mapGroup function.
*/
public interface MapGroupFunction<K, V, R> extends Serializable {
  R call(K key, Iterator<V> values) throws Exception;
diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java
new file mode 100644
index 000000000000..6c576ab67845
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A two-argument function that takes arguments of type T1 and T2 with no return value.
+ */
+public interface VoidFunction2<T1, T2> extends Serializable {
+ public void call(T1 v1, T2 v2) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 5f743b28857b..d31eb449eb82 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -215,6 +215,9 @@ public void showMemoryUsage() {
logger.info(
"{} bytes of memory were used by task {} but are not associated with specific consumers",
memoryNotAccountedFor, taskAttemptId);
+ logger.info(
+ "{} bytes of memory are used for execution and {} bytes of memory are used for storage",
+ memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed());
}
}
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 04694dc54418..3387f9a4177c 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -24,6 +24,7 @@
import java.util.LinkedList;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -272,6 +273,7 @@ private void advanceToNextPage() {
}
}
try {
+ Closeables.close(reader, /* swallowIOException = */ false);
reader = spillWriters.getFirst().getReader(blockManager);
recordsInPage = -1;
} catch (IOException e) {
@@ -318,6 +320,11 @@ public Location next() {
try {
reader.loadNext();
} catch (IOException e) {
+ try {
+ reader.close();
+ } catch(IOException e2) {
+ logger.error("Error while closing spill reader", e2);
+ }
// Scala iterator does not handle exception
Platform.throwException(e);
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 039e940a357e..dcb13e6581e5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,8 +20,7 @@
import java.io.*;
import com.google.common.io.ByteStreams;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.io.Closeables;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
@@ -31,10 +30,8 @@
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
* of the file format).
*/
-public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
- private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
- private final File file;
private InputStream in;
private DataInputStream din;
@@ -52,11 +49,15 @@ public UnsafeSorterSpillReader(
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
- this.file = file;
final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
- this.in = blockManager.wrapForCompression(blockId, bs);
- this.din = new DataInputStream(this.in);
- numRecordsRemaining = din.readInt();
+ try {
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt();
+ } catch (IOException e) {
+ Closeables.close(bs, /* swallowIOException = */ true);
+ throw e;
+ }
}
@Override
@@ -75,12 +76,7 @@ public void loadNext() throws IOException {
ByteStreams.readFully(in, arr, 0, recordLength);
numRecordsRemaining--;
if (numRecordsRemaining == 0) {
- in.close();
- if (!file.delete() && file.exists()) {
- logger.warn("Unable to delete spill file {}", file.getPath());
- }
- in = null;
- din = null;
+ close();
}
}
@@ -103,4 +99,16 @@ public int getRecordLength() {
public long getKeyPrefix() {
return keyPrefix;
}
+
+ @Override
+ public void close() throws IOException {
+ if (in != null) {
+ try {
+ in.close();
+ } finally {
+ in = null;
+ din = null;
+ }
+ }
+ }
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
index dde6069000bc..a73d9a5cbc21 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
@@ -89,7 +89,7 @@ sorttable = {
// make it clickable to sort
headrow[i].sorttable_columnindex = i;
headrow[i].sorttable_tbody = table.tBodies[0];
- dean_addEvent(headrow[i],"click", function(e) {
+ dean_addEvent(headrow[i],"click", sorttable.innerSortFunction = function(e) {
if (this.className.search(/\bsorttable_sorted\b/) != -1) {
// if we're already sorted by this column, just
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index b93536e6536e..6419218f47c8 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -509,6 +509,7 @@ private[spark] class ExecutorAllocationManager(
private def onExecutorBusy(executorId: String): Unit = synchronized {
logDebug(s"Clearing idle timer for $executorId because it is now running a task")
removeTimes.remove(executorId)
+ executorsPendingToRemove.remove(executorId)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 4bbd0b038c00..af4456c05b0a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -581,6 +581,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// Post init
_taskScheduler.postStartHook()
+ _env.metricsSystem.registerSource(_dagScheduler.metricsSource)
_env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager))
_executorAllocationManager.foreach { e =>
_env.metricsSystem.registerSource(e.executorAllocationManagerSource)
@@ -1461,7 +1462,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
override def killExecutors(executorIds: Seq[String]): Boolean = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
- b.killExecutors(executorIds)
+ b.killExecutors(executorIds, replace = false, force = true)
case _ =>
logWarning("Killing executors is only supported in coarse-grained mode")
false
@@ -1499,7 +1500,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private[spark] def killAndReplaceExecutor(executorId: String): Boolean = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
- b.killExecutors(Seq(executorId), replace = true)
+ b.killExecutors(Seq(executorId), replace = true, force = true)
case _ =>
logWarning("Killing executors is only supported in coarse-grained mode")
false
@@ -2710,7 +2711,7 @@ object SparkContext extends Logging {
case mesosUrl @ MESOS_REGEX(_) =>
MesosNativeLibrary.load()
val scheduler = new TaskSchedulerImpl(sc)
- val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false)
+ val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true)
val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4474a83bedbd..88df27f733f2 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -258,8 +258,15 @@ object SparkEnv extends Logging {
if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
} else {
+ val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1
// Create a ActorSystem for legacy codes
- AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1
+ AkkaUtils.createActorSystem(
+ actorSystemName + "ActorSystem",
+ hostname,
+ actorSystemPort,
+ conf,
+ securityManager
+ )._1
}
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index acfe751f6c74..43c89b258f2f 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.nio.charset.StandardCharsets
+import java.nio.file.Paths
import java.util.Arrays
import java.util.jar.{JarEntry, JarOutputStream}
@@ -83,15 +84,15 @@ private[spark] object TestUtils {
}
/**
- * Create a jar file that contains this set of files. All files will be located at the root
- * of the jar.
+ * Create a jar file that contains this set of files. All files will be located in the specified
+ * directory or at the root of the jar.
*/
- def createJar(files: Seq[File], jarFile: File): URL = {
+ def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = {
val jarFileStream = new FileOutputStream(jarFile)
val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())
for (file <- files) {
- val jarEntry = new JarEntry(file.getName)
+ val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString)
jarStream.putNextEntry(jarEntry)
val in = new FileInputStream(file)
@@ -123,7 +124,7 @@ private[spark] object TestUtils {
classpathUrls: Seq[URL]): File = {
val compiler = ToolProvider.getSystemJavaCompiler
- // Calling this outputs a class file in pwd. It's easier to just rename the file than
+ // Calling this outputs a class file in pwd. It's easier to just rename the files than
// build a custom FileManager that controls the output location.
val options = if (classpathUrls.nonEmpty) {
Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator))
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index b7e72d4d0ed0..8b3be0da2c8c 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -113,6 +113,7 @@ private[spark] object RBackend extends Logging {
val dos = new DataOutputStream(new FileOutputStream(f))
dos.writeInt(boundPort)
dos.writeInt(listenPort)
+ SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
dos.close()
f.renameTo(new File(path))
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 6b418e908cb5..7509b3d3f44b 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -400,14 +400,14 @@ private[r] object RRDD {
val rOptions = "--vanilla"
val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
- val rExecScript = rLibDir + "/SparkR/worker/" + script
+ val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
// Unset the R_TESTS environment variable for workers.
// This is set by R CMD check as startup.Rs
// (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
// and confuses worker script which tries to load a non-existent file
pb.environment().put("R_TESTS", "")
- pb.environment().put("SPARKR_RLIBDIR", rLibDir)
+ pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
pb.environment().put("SPARKR_WORKER_PORT", port.toString)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index fd5646b5b637..16157414fd12 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -23,6 +23,10 @@ import java.util.Arrays
import org.apache.spark.{SparkEnv, SparkException}
private[spark] object RUtils {
+  // Local path where R binary packages built from the R source code contained in Spark
+  // packages (specified with the "--packages" or "--jars" command line options) reside.
+ var rPackages: Option[String] = None
+
/**
* Get the SparkR package path in the local spark distribution.
*/
@@ -34,11 +38,15 @@ private[spark] object RUtils {
}
/**
- * Get the SparkR package path in various deployment modes.
+ * Get the list of paths for R packages in various deployment modes, of which the first
+ * path is for the SparkR package itself. The second path is for R packages built as
+ * part of Spark Packages, if any exist. Spark Packages can be provided through the
+ * "--packages" or "--jars" command line options.
+ *
* This assumes that Spark properties `spark.master` and `spark.submit.deployMode`
* and environment variable `SPARK_HOME` are set.
*/
- def sparkRPackagePath(isDriver: Boolean): String = {
+ def sparkRPackagePath(isDriver: Boolean): Seq[String] = {
val (master, deployMode) =
if (isDriver) {
(sys.props("spark.master"), sys.props("spark.submit.deployMode"))
@@ -51,15 +59,30 @@ private[spark] object RUtils {
val isYarnClient = master != null && master.contains("yarn") && deployMode == "client"
// In YARN mode, the SparkR package is distributed as an archive symbolically
- // linked to the "sparkr" file in the current directory. Note that this does not apply
- // to the driver in client mode because it is run outside of the cluster.
+ // linked to the "sparkr" file in the current directory and additional R packages
+ // are distributed as an archive symbolically linked to the "rpkg" file in the
+ // current directory.
+ //
+ // Note that this does not apply to the driver in client mode because it is run
+ // outside of the cluster.
if (isYarnCluster || (isYarnClient && !isDriver)) {
- new File("sparkr").getAbsolutePath
+ val sparkRPkgPath = new File("sparkr").getAbsolutePath
+ val rPkgPath = new File("rpkg")
+ if (rPkgPath.exists()) {
+ Seq(sparkRPkgPath, rPkgPath.getAbsolutePath)
+ } else {
+ Seq(sparkRPkgPath)
+ }
} else {
// Otherwise, assume the package is local
// TODO: support this for Mesos
- localSparkRPackagePath.getOrElse {
- throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
+ val sparkRPkgPath = localSparkRPackagePath.getOrElse {
+ throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
+ }
+ if (!rPackages.isEmpty) {
+ Seq(sparkRPkgPath, rPackages.get)
+ } else {
+ Seq(sparkRPkgPath)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index a039d543c35e..e8a1e35c3fc4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -45,7 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
- private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
+ private val transportConf =
+ SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0)
private val blockHandler = newShuffleBlockHandler(transportConf)
private val transportContext: TransportContext =
new TransportContext(transportConf, blockHandler, true)
diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
index 7d160b6790ea..d46dc87a92c9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
@@ -100,20 +100,29 @@ private[deploy] object RPackageUtils extends Logging {
* Runs the standard R package installation code to build the R package from source.
* Multiple runs don't cause problems.
*/
- private def rPackageBuilder(dir: File, printStream: PrintStream, verbose: Boolean): Boolean = {
+ private def rPackageBuilder(
+ dir: File,
+ printStream: PrintStream,
+ verbose: Boolean,
+ libDir: String): Boolean = {
// this code should be always running on the driver.
- val pathToSparkR = RUtils.localSparkRPackagePath.getOrElse(
- throw new SparkException("SPARK_HOME not set. Can't locate SparkR package."))
val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator)
- val installCmd = baseInstallCmd ++ Seq(pathToSparkR, pathToPkg)
+ val installCmd = baseInstallCmd ++ Seq(libDir, pathToPkg)
if (verbose) {
print(s"Building R package with the command: $installCmd", printStream)
}
try {
val builder = new ProcessBuilder(installCmd.asJava)
builder.redirectErrorStream(true)
+
+      // Put the SparkR package directory into the R library search paths in case this R
+      // package depends on SparkR.
val env = builder.environment()
- env.clear()
+ val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
+ env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
+ env.put("R_PROFILE_USER",
+ Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator))
+
val process = builder.start()
new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start()
process.waitFor() == 0
@@ -170,8 +179,11 @@ private[deploy] object RPackageUtils extends Logging {
if (checkManifestForR(jar)) {
print(s"$file contains R source code. Now installing package.", printStream, Level.INFO)
val rSource = extractRFolder(jar, printStream, verbose)
+ if (RUtils.rPackages.isEmpty) {
+ RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath)
+ }
try {
- if (!rPackageBuilder(rSource, printStream, verbose)) {
+ if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) {
print(s"ERROR: Failed to build R package in $file.", printStream)
print(RJarDoc, printStream)
}
@@ -208,7 +220,7 @@ private[deploy] object RPackageUtils extends Logging {
}
}
- /** Zips all the libraries found with SparkR in the R/lib directory for distribution with Yarn. */
+ /** Zips all the R libraries built for distribution to the cluster. */
private[deploy] def zipRLibraries(dir: File, name: String): File = {
val filesToBundle = listFilesRecursively(dir, Seq(".zip"))
// create a zip file from scratch, do not append to existing file.
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index ed183cf16a9c..661f7317c674 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -82,9 +82,10 @@ object RRunner {
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
- env.put("SPARKR_PACKAGE_DIR", rPackageDir)
+ // Put the R package directories into an env variable of comma-separated paths
+ env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
env.put("R_PROFILE_USER",
- Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator))
+ Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator))
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index d606b80c03c9..59e90564b351 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -92,10 +92,15 @@ class SparkHadoopUtil extends Logging {
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
- hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
- hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
- hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
- hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ val keyId = System.getenv("AWS_ACCESS_KEY_ID")
+ val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY")
+
+ hadoopConf.set("fs.s3.awsAccessKeyId", keyId)
+ hadoopConf.set("fs.s3n.awsAccessKeyId", keyId)
+ hadoopConf.set("fs.s3a.access.key", keyId)
+ hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey)
+ hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey)
+ hadoopConf.set("fs.s3a.secret.key", accessKey)
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
conf.getAll.foreach { case (key, value) =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 84ae122f4437..2e912b59afdb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -39,7 +39,7 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.repository.file.FileRepository
import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver}
-import org.apache.spark.{SparkUserAppException, SPARK_VERSION}
+import org.apache.spark.{SparkException, SparkUserAppException, SPARK_VERSION}
import org.apache.spark.api.r.RUtils
import org.apache.spark.deploy.rest._
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
@@ -83,6 +83,7 @@ object SparkSubmit {
private val PYSPARK_SHELL = "pyspark-shell"
private val SPARKR_SHELL = "sparkr-shell"
private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip"
+ private val R_PACKAGE_ARCHIVE = "rpkg.zip"
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
@@ -362,22 +363,46 @@ object SparkSubmit {
}
}
- // In YARN mode for an R app, add the SparkR package archive to archives
- // that can be distributed with the job
+ // In YARN mode for an R app, add the SparkR package archive and the R package
+ // archive containing all of the built R libraries to archives so that they can
+ // be distributed with the job
if (args.isR && clusterManager == YARN) {
- val rPackagePath = RUtils.localSparkRPackagePath
- if (rPackagePath.isEmpty) {
+ val sparkRPackagePath = RUtils.localSparkRPackagePath
+ if (sparkRPackagePath.isEmpty) {
printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.")
}
- val rPackageFile =
- RPackageUtils.zipRLibraries(new File(rPackagePath.get), SPARKR_PACKAGE_ARCHIVE)
- if (!rPackageFile.exists()) {
+ val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE)
+ if (!sparkRPackageFile.exists()) {
printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.")
}
- val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath)
+ val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString
+ // Distribute the SparkR package.
// Assigns a symbol link name "sparkr" to the shipped package.
- args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr")
+ args.archives = mergeFileLists(args.archives, sparkRPackageURI + "#sparkr")
+
+ // Distribute the R package archive containing all the built R packages.
+ if (!RUtils.rPackages.isEmpty) {
+ val rPackageFile =
+ RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE)
+ if (!rPackageFile.exists()) {
+ printErrorAndExit("Failed to zip all the built R packages.")
+ }
+
+ val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString
+ // Assigns a symbolic link name "rpkg" to the shipped package.
+ args.archives = mergeFileLists(args.archives, rPackageURI + "#rpkg")
+ }
+ }
+
+ // TODO: Support distributing R packages with standalone cluster
+ if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) {
+ printErrorAndExit("Distributing R packages with standalone cluster is not supported.")
+ }
+
+ // TODO: Support SparkR with mesos cluster
+ if (args.isR && clusterManager == MESOS) {
+ printErrorAndExit("SparkR is not supported for Mesos cluster.")
}
// If we're running a R app, set the main class to our specific R runner
@@ -521,8 +546,19 @@ object SparkSubmit {
sysProps.put("spark.yarn.isPython", "true")
}
if (args.principal != null) {
- require(args.keytab != null, "Keytab must be specified when the keytab is specified")
- UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab)
+ require(args.keytab != null, "Keytab must be specified when principal is specified")
+ if (!new File(args.keytab).exists()) {
+ throw new SparkException(s"Keytab file: ${args.keytab} does not exist")
+ } else {
+ // Add keytab and principal configurations in sysProps to make them available
+ // for later use; e.g. in Spark SQL, the isolated class loader used to talk
+ // to the Hive metastore will use these settings. They will be set as Java system
+ // properties and then loaded by SparkConf.
+ sysProps.put("spark.yarn.keytab", args.keytab)
+ sysProps.put("spark.yarn.principal", args.principal)
+
+ UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
index 957a928bc402..f0dd667ea1b2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
@@ -19,16 +19,19 @@ package org.apache.spark.deploy.rest
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{ConnectException, HttpURLConnection, SocketException, URL}
+import java.util.concurrent.TimeoutException
import javax.servlet.http.HttpServletResponse
import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future}
import scala.io.Source
import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.base.Charsets
-import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf}
/**
* A client that submits applications to a [[RestSubmissionServer]].
@@ -225,7 +228,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
* Exposed for testing.
*/
private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
- try {
+ import scala.concurrent.ExecutionContext.Implicits.global
+ val responseFuture = Future {
val dataStream =
if (connection.getResponseCode == HttpServletResponse.SC_OK) {
connection.getInputStream
@@ -251,11 +255,15 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
throw new SubmitRestProtocolException(
s"Message received from server was not a response:\n${unexpected.toJson}")
}
- } catch {
+ }
+
+ try { Await.result(responseFuture, 10.seconds) } catch {
case unreachable @ (_: FileNotFoundException | _: SocketException) =>
throw new SubmitRestConnectionException("Unable to connect to server", unreachable)
case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
throw new SubmitRestProtocolException("Malformed response received from server", malformed)
+ case timeout: TimeoutException =>
+ throw new SubmitRestConnectionException("No response from server", timeout)
}
}
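
A minimal standalone sketch of the Future/Await pattern used above, for readers skimming the hunk: a potentially hanging read is wrapped in a Future and bounded by Await.result, and a timeout is mapped to a connection-level failure. The helper name and the exception type thrown here are illustrative, not part of the patch.

    import java.util.concurrent.TimeoutException
    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // Run a blocking body on the global execution context and give up after a fixed timeout.
    def readWithTimeout[T](body: => T): T = {
      val responseFuture = Future { body }
      try {
        Await.result(responseFuture, 10.seconds)
      } catch {
        case timeout: TimeoutException =>
          // Mirrors the patch: a timeout becomes a "no response from server" failure.
          throw new RuntimeException("No response from server", timeout)
      }
    }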
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 70a42f9045e6..b0694e3c6c8a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -41,7 +41,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
private val serializer = new JavaSerializer(conf)
private val authEnabled = securityManager.isAuthenticationEnabled()
- private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores)
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
index cef203006d68..84833f59d7af 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -40,23 +40,23 @@ object SparkTransportConf {
/**
* Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ * @param _conf the [[SparkConf]]
+ * @param module the module name
* @param numUsableCores if nonzero, this will restrict the server and client threads to only
* use the given number of cores, rather than all of the machine's cores.
* This restriction will only occur if these properties are not already set.
*/
- def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = {
+ def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = {
val conf = _conf.clone
// Specify thread configuration based on our JVM's allocation of cores (rather than necessarily
// assuming we have all the machine's cores).
// NB: Only set if serverThreads/clientThreads not already set.
val numThreads = defaultNumThreads(numUsableCores)
- conf.set("spark.shuffle.io.serverThreads",
- conf.get("spark.shuffle.io.serverThreads", numThreads.toString))
- conf.set("spark.shuffle.io.clientThreads",
- conf.get("spark.shuffle.io.clientThreads", numThreads.toString))
+ conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString)
+ conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString)
- new TransportConf(new ConfigProvider {
+ new TransportConf(module, new ConfigProvider {
override def get(name: String): String = conf.get(name)
})
}
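
With the new module parameter, the thread-count settings are read from module-scoped keys (spark.<module>.io.serverThreads / clientThreads), so shuffle and RPC traffic can be tuned independently. A caller-side sketch using the signature added above; it assumes Spark-internal access, and the values shown are illustrative.

    import org.apache.spark.SparkConf
    import org.apache.spark.network.netty.SparkTransportConf

    val sparkConf = new SparkConf()
      // Only picked up by the "shuffle" module; the "rpc" module falls back to its default.
      .set("spark.shuffle.io.serverThreads", "8")

    val shuffleConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 4)
    val rpcConf = SparkTransportConf.fromSparkConf(sparkConf, "rpc")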
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index ca1eb1f4e4a9..d5e853613b05 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -66,6 +66,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
val f = new ComplexFutureAction[Seq[T]]
+ val callSite = self.context.getCallSite
f.run {
// This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
@@ -73,6 +74,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
val results = new ArrayBuffer[T](num)
val totalParts = self.partitions.length
var partsScanned = 0
+ self.context.setCallSite(callSite)
while (results.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 0453614f6a1d..f37c95bedc0a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -213,6 +213,12 @@ class HadoopRDD[K, V](
val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+ // Set the thread-local variable holding the current input file's name
+ split.inputSplit.value match {
+ case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
+ case _ => SqlNewHadoopRDDState.unsetInputFileName()
+ }
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -250,6 +256,7 @@ class HadoopRDD[K, V](
override def close() {
if (reader != null) {
+ SqlNewHadoopRDDState.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index d6a37e8cc5da..0c6ddda52cee 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag](
}
override protected def getPartitions: Array[Partition] =
- getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
+ dependencies.head.asInstanceOf[PruneDependency[T]].partitions
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 800ef53cbef0..2aeb5eeaad32 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -705,6 +705,24 @@ abstract class RDD[T: ClassTag](
preservesPartitioning)
}
+ /**
+ * [performance] Spark's internal mapPartitions method that skips closure cleaning. It should
+ * be used carefully, and only when we are sure that the supplied function is serializable
+ * and does not require closure cleaning.
+ *
+ * @param preservesPartitioning indicates whether the input function preserves the partitioner,
+ * which should be `false` unless this is a pair RDD and the input function doesn't modify
+ * the keys.
+ */
+ private[spark] def mapPartitionsInternal[U: ClassTag](
+ f: Iterator[T] => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = withScope {
+ new MapPartitionsRDD(
+ this,
+ (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter),
+ preservesPartitioning)
+ }
+
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
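
A short usage sketch of mapPartitionsInternal (the wrapper function is illustrative; the method is private[spark], so this assumes code inside the org.apache.spark package). The only difference from mapPartitions is that the function is passed through without running the ClosureCleaner.

    import org.apache.spark.rdd.RDD

    // Visible only to Spark-internal code.
    def doubleAll(rdd: RDD[Int]): RDD[Int] =
      rdd.mapPartitionsInternal(iter => iter.map(_ * 2))  // function known to be serializable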
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
new file mode 100644
index 000000000000..3f15fff79366
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * State for SqlNewHadoopRDD objects. It is kept in a separate object, here in core, because
+ * of the package split between core and sql.
+ * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD
+ */
+private[spark] object SqlNewHadoopRDDState {
+ /**
+ * The thread variable for the name of the current file being read. This is used by
+ * the InputFileName function in Spark SQL.
+ */
+ private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] {
+ override protected def initialValue(): UTF8String = UTF8String.fromString("")
+ }
+
+ def getInputFileName(): UTF8String = inputFileName.get()
+
+ private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file))
+
+ private[spark] def unsetInputFileName(): Unit = inputFileName.remove()
+
+}
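
A minimal sketch of the intended set/read/unset pattern for the thread-local above, matching how HadoopRDD uses it earlier in this patch. The file path is illustrative, and setInputFileName is private[spark], so this assumes Spark-internal code.

    import org.apache.spark.rdd.SqlNewHadoopRDDState

    SqlNewHadoopRDDState.setInputFileName("hdfs://nn/path/part-00000")
    try {
      // Anything running on this thread (e.g. SQL's InputFileName expression) sees the name.
      val current = SqlNewHadoopRDDState.getInputFileName().toString
      assert(current == "hdfs://nn/path/part-00000")
    } finally {
      // Reset so a reused task thread does not leak the previous file name.
      SqlNewHadoopRDDState.unsetInputFileName()
    }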
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 3fad595a0d0b..059a7e10ec12 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -263,7 +263,7 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging
}
override def receiveWithLogging: Actor.Receive = {
- case Error(cause: Throwable, _, _, message: String) => logError(message, cause)
+ case Error(cause: Throwable, _, _, message: String) => logDebug(message, cause)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 09093819bb22..3ce359868039 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -22,16 +22,13 @@ import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
-import javax.annotation.Nullable;
-import javax.annotation.concurrent.GuardedBy
+import javax.annotation.Nullable
-import scala.collection.mutable
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
import scala.util.{DynamicVariable, Failure, Success}
import scala.util.control.NonFatal
-import com.google.common.base.Preconditions
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client._
@@ -49,7 +46,8 @@ private[netty] class NettyRpcEnv(
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
private val transportConf = SparkTransportConf.fromSparkConf(
- conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
+ conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
+ "rpc",
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
@@ -104,7 +102,7 @@ private[netty] class NettyRpcEnv(
} else {
java.util.Collections.emptyList()
}
- server = transportContext.createServer(port, bootstraps)
+ server = transportContext.createServer(host, port, bootstraps)
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
@@ -339,10 +337,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
if (!config.clientMode) {
val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
nettyEnv.startServer(actualPort)
- (nettyEnv, actualPort)
+ (nettyEnv, nettyEnv.address.port)
}
try {
- Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+ Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
} catch {
case NonFatal(e) =>
nettyEnv.shutdown()
@@ -372,7 +370,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
* @param conf Spark configuration.
* @param endpointAddress The address where the endpoint is listening.
* @param nettyEnv The RpcEnv associated with this ref.
- * @param local Whether the referenced endpoint lives in the same process.
*/
private[netty] class NettyRpcEndpointRef(
@transient private val conf: SparkConf,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4a9518fff4e7..ae725b467d8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -130,7 +130,7 @@ class DAGScheduler(
def this(sc: SparkContext) = this(sc, sc.taskScheduler)
- private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this)
+ private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this)
private[scheduler] val nextJobId = new AtomicInteger(0)
private[scheduler] def numTotalJobs: Int = nextJobId.get()
@@ -1580,8 +1580,6 @@ class DAGScheduler(
taskScheduler.stop()
}
- // Start the event thread and register the metrics source at the end of the constructor
- env.metricsSystem.registerSource(metricsSource)
eventProcessLoop.start()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 47a5cbff4930..7e1197d74280 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -40,6 +40,8 @@ private[spark] object ExecutorExited {
}
}
+private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.")
+
/**
* A loss reason that means we don't yet know why the executor exited.
*
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 180c8d1827e1..b2e9a97129f0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,8 +19,9 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import org.roaringbitmap.RoaringBitmap
+
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.collection.BitSet
import org.apache.spark.util.Utils
/**
@@ -121,8 +122,7 @@ private[spark] class CompressedMapStatus(
/**
* A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
- * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap
- * is compressed.
+ * plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
* @param numNonEmptyBlocks the number of non-empty blocks
@@ -132,7 +132,7 @@ private[spark] class CompressedMapStatus(
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
private[this] var numNonEmptyBlocks: Int,
- private[this] var emptyBlocks: BitSet,
+ private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long)
extends MapStatus with Externalizable {
@@ -145,7 +145,7 @@ private[spark] class HighlyCompressedMapStatus private (
override def location: BlockManagerId = loc
override def getSizeForBlock(reduceId: Int): Long = {
- if (emptyBlocks.get(reduceId)) {
+ if (emptyBlocks.contains(reduceId)) {
0
} else {
avgSize
@@ -160,7 +160,7 @@ private[spark] class HighlyCompressedMapStatus private (
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
- emptyBlocks = new BitSet
+ emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
}
@@ -176,15 +176,15 @@ private[spark] object HighlyCompressedMapStatus {
// From a compression standpoint, it shouldn't matter whether we track empty or non-empty
// blocks. From a performance standpoint, we benefit from tracking empty blocks because
// we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
+ val emptyBlocks = new RoaringBitmap()
val totalNumBlocks = uncompressedSizes.length
- val emptyBlocks = new BitSet(totalNumBlocks)
while (i < totalNumBlocks) {
var size = uncompressedSizes(i)
if (size > 0) {
numNonEmptyBlocks += 1
totalSize += size
} else {
- emptyBlocks.set(i)
+ emptyBlocks.add(i)
}
i += 1
}
@@ -193,6 +193,8 @@ private[spark] object HighlyCompressedMapStatus {
} else {
0
}
+ emptyBlocks.trim()
+ emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 46a6f6537e2e..f4965994d827 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
+ val loader = Utils.getContextOrSparkClassLoader
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
- serializedData, Utils.getSparkClassLoader)
+ serializedData, loader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic
// if we can't deserialize the reason.
- val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Exception => {}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 43d7d80b7aae..bdf19f9f277d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -87,8 +87,8 @@ private[spark] class TaskSchedulerImpl(
// Incrementing task IDs
val nextTaskId = new AtomicLong(0)
- // Which executor IDs we have executors on
- val activeExecutorIds = new HashSet[String]
+ // Number of tasks running on each executor
+ private val executorIdToTaskCount = new HashMap[String, Int]
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
@@ -254,6 +254,7 @@ private[spark] class TaskSchedulerImpl(
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
+ executorIdToTaskCount(execId) += 1
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
@@ -282,7 +283,7 @@ private[spark] class TaskSchedulerImpl(
var newExecAvail = false
for (o <- offers) {
executorIdToHost(o.executorId) = o.host
- activeExecutorIds += o.executorId
+ executorIdToTaskCount.getOrElseUpdate(o.executorId, 0)
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
executorAdded(o.executorId, o.host)
@@ -331,7 +332,8 @@ private[spark] class TaskSchedulerImpl(
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
// We lost this entire executor, so remember that it's gone
val execId = taskIdToExecutorId(tid)
- if (activeExecutorIds.contains(execId)) {
+
+ if (executorIdToTaskCount.contains(execId)) {
removeExecutor(execId,
SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
failedExecutor = Some(execId)
@@ -341,7 +343,11 @@ private[spark] class TaskSchedulerImpl(
case Some(taskSet) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetManager.remove(tid)
- taskIdToExecutorId.remove(tid)
+ taskIdToExecutorId.remove(tid).foreach { execId =>
+ if (executorIdToTaskCount.contains(execId)) {
+ executorIdToTaskCount(execId) -= 1
+ }
+ }
}
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
@@ -462,26 +468,27 @@ private[spark] class TaskSchedulerImpl(
var failedExecutor: Option[String] = None
synchronized {
- if (activeExecutorIds.contains(executorId)) {
+ if (executorIdToTaskCount.contains(executorId)) {
val hostPort = executorIdToHost(executorId)
- logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
+ logExecutorLoss(executorId, hostPort, reason)
removeExecutor(executorId, reason)
failedExecutor = Some(executorId)
} else {
- executorIdToHost.get(executorId) match {
- case Some(_) =>
- // If the host mapping still exists, it means we don't know the loss reason for the
- // executor. So call removeExecutor() to update tasks running on that executor when
- // the real loss reason is finally known.
- removeExecutor(executorId, reason)
-
- case None =>
- // We may get multiple executorLost() calls with different loss reasons. For example,
- // one may be triggered by a dropped connection from the slave while another may be a
- // report of executor termination from Mesos. We produce log messages for both so we
- // eventually report the termination reason.
- logError("Lost an executor " + executorId + " (already removed): " + reason)
- }
+ executorIdToHost.get(executorId) match {
+ case Some(hostPort) =>
+ // If the host mapping still exists, it means we don't know the loss reason for the
+ // executor. So call removeExecutor() to update tasks running on that executor when
+ // the real loss reason is finally known.
+ logExecutorLoss(executorId, hostPort, reason)
+ removeExecutor(executorId, reason)
+
+ case None =>
+ // We may get multiple executorLost() calls with different loss reasons. For example,
+ // one may be triggered by a dropped connection from the slave while another may be a
+ // report of executor termination from Mesos. We produce log messages for both so we
+ // eventually report the termination reason.
+ logError(s"Lost an executor $executorId (already removed): $reason")
+ }
}
}
// Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
@@ -491,13 +498,26 @@ private[spark] class TaskSchedulerImpl(
}
}
+ private def logExecutorLoss(
+ executorId: String,
+ hostPort: String,
+ reason: ExecutorLossReason): Unit = reason match {
+ case LossReasonPending =>
+ logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.")
+ case ExecutorKilled =>
+ logInfo(s"Executor $executorId on $hostPort killed by driver.")
+ case _ =>
+ logError(s"Lost executor $executorId on $hostPort: $reason")
+ }
+
/**
* Remove an executor from all our data structures and mark it as lost. If the executor's loss
* reason is not yet known, do not yet remove its association with its host nor update the status
* of any running tasks, since the loss reason defines whether we'll fail those tasks.
*/
private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
- activeExecutorIds -= executorId
+ executorIdToTaskCount -= executorId
+
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
@@ -534,7 +554,11 @@ private[spark] class TaskSchedulerImpl(
}
def isExecutorAlive(execId: String): Boolean = synchronized {
- activeExecutorIds.contains(execId)
+ executorIdToTaskCount.contains(execId)
+ }
+
+ def isExecutorBusy(execId: String): Boolean = synchronized {
+ executorIdToTaskCount.getOrElse(execId, -1) > 0
}
// By default, rack is unknown
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 114468c48c44..a02f3017cb6e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -800,6 +800,7 @@ private[spark] class TaskSetManager(
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
val exitCausedByApp: Boolean = reason match {
case exited: ExecutorExited => exited.exitCausedByApp
+ case ExecutorKilled => false
case _ => true
}
handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f71d98feac05..505c161141c8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -64,8 +64,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
private val listenerBus = scheduler.sc.listenerBus
- // Executors we have requested the cluster manager to kill that have not died yet
- private val executorsPendingToRemove = new HashSet[String]
+ // Executors we have requested the cluster manager to kill that have not died yet; maps
+ // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't
+ // be considered an app-related failure).
+ private val executorsPendingToRemove = new HashMap[String, Boolean]
// A map to store hostname with its possible task number running on it
protected var hostToLocalTaskCount: Map[String, Int] = Map.empty
@@ -250,15 +252,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
case Some(executorInfo) =>
// This must be synchronized because variables mutated
// in this block are read when requesting executors
- CoarseGrainedSchedulerBackend.this.synchronized {
+ val killed = CoarseGrainedSchedulerBackend.this.synchronized {
addressToExecutorId -= executorInfo.executorAddress
executorDataMap -= executorId
- executorsPendingToRemove -= executorId
executorsPendingLossReason -= executorId
+ executorsPendingToRemove.remove(executorId).getOrElse(false)
}
totalCoreCount.addAndGet(-executorInfo.totalCores)
totalRegisteredExecutors.addAndGet(-1)
- scheduler.executorLost(executorId, reason)
+ scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason)
listenerBus.post(
SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString))
case None => logInfo(s"Asked to remove non-existent executor $executorId")
@@ -269,7 +271,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
* Stop making resource offers for the given executor. The executor is marked as lost with
* the loss reason still pending.
*
- * @return Whether executor was alive.
+ * @return Whether the executor should be disabled.
*/
protected def disableExecutor(executorId: String): Boolean = {
val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized {
@@ -277,7 +279,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
executorsPendingLossReason += executorId
true
} else {
- false
+ // Also return true for explicitly killed executors, since we still need to get their
+ // pending loss reasons; for all other executors return false.
+ executorsPendingToRemove.contains(executorId)
}
}
@@ -451,17 +455,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
* @return whether the kill request is acknowledged.
*/
final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
- killExecutors(executorIds, replace = false)
+ killExecutors(executorIds, replace = false, force = false)
}
/**
* Request that the cluster manager kill the specified executors.
*
+ * When an executor is asked to be replaced, its loss is considered a failure, and the killed
+ * tasks that were running on it count towards the failure limits. If no replacement is
+ * requested, then those tasks do not count towards the limits.
+ *
* @param executorIds identifiers of executors to kill
* @param replace whether to replace the killed executors with new ones
+ * @param force whether to force kill busy executors
* @return whether the kill request is acknowledged.
*/
- final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized {
+ final def killExecutors(
+ executorIds: Seq[String],
+ replace: Boolean,
+ force: Boolean): Boolean = synchronized {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
unknownExecutors.foreach { id =>
@@ -469,8 +481,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
// If an executor is already pending to be removed, do not kill it again (SPARK-9795)
- val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) }
- executorsPendingToRemove ++= executorsToKill
+ // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552)
+ val executorsToKill = knownExecutors
+ .filter { id => !executorsPendingToRemove.contains(id) }
+ .filter { id => force || !scheduler.isExecutorBusy(id) }
+ executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace }
// If we do not wish to replace the executors we kill, sync the target number of executors
// with the cluster manager to avoid allocating new ones. When computing the new target,
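
A caller-side sketch of the new force flag, based only on the signature added above; `backend` stands for a CoarseGrainedSchedulerBackend instance and the executor ID is illustrative.

    // Skips "exec-1" if it still has running tasks:
    backend.killExecutors(Seq("exec-1"), replace = false, force = false)

    // Kills it even while tasks are running; with replace = false the killed tasks
    // do not count towards the application's failure limits:
    backend.killExecutors(Seq("exec-1"), replace = false, force = true)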
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 2de9b6a65169..7d08eae0b487 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -109,7 +109,7 @@ private[spark] class CoarseMesosSchedulerBackend(
private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = {
if (shuffleServiceEnabled) {
Some(new MesosExternalShuffleClient(
- SparkTransportConf.fromSparkConf(conf),
+ SparkTransportConf.fromSparkConf(conf, "shuffle"),
securityManager,
securityManager.isAuthenticationEnabled(),
securityManager.isSaslEncryptionEnabled()))
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index bc51d4f2820c..d5ba690ed04b 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.serializer
-import java.io.{EOFException, IOException, InputStream, OutputStream}
+import java.io.{DataInput, DataOutput, EOFException, IOException, InputStream, OutputStream}
import java.nio.ByteBuffer
import javax.annotation.Nullable
@@ -25,11 +25,12 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
-import com.esotericsoftware.kryo.{Kryo, KryoException}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
+import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.avro.generic.{GenericData, GenericRecord}
+import org.roaringbitmap.RoaringBitmap
import org.apache.spark._
import org.apache.spark.api.python.PythonBroadcast
@@ -37,8 +38,8 @@ import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
-import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf}
-import org.apache.spark.util.collection.{BitSet, CompactBuffer}
+import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
@@ -93,6 +94,9 @@ class KryoSerializer(conf: SparkConf)
for (cls <- KryoSerializer.toRegister) {
kryo.register(cls)
}
+ for ((cls, ser) <- KryoSerializer.toRegisterSerializer) {
+ kryo.register(cls, ser)
+ }
// For results returned by asJavaIterable. See JavaIterableWrapperSerializer.
kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer)
@@ -362,7 +366,6 @@ private[serializer] object KryoSerializer {
classOf[StorageLevel],
classOf[CompressedMapStatus],
classOf[HighlyCompressedMapStatus],
- classOf[BitSet],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Byte]],
@@ -371,6 +374,55 @@ private[serializer] object KryoSerializer {
classOf[BoundedPriorityQueue[_]],
classOf[SparkConf]
)
+
+ private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]](
+ classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() {
+ override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = {
+ bitmap.serialize(new KryoOutputDataOutputBridge(output))
+ }
+ override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = {
+ val ret = new RoaringBitmap
+ ret.deserialize(new KryoInputDataInputBridge(input))
+ ret
+ }
+ }
+ )
+}
+
+private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput {
+ override def readLong(): Long = input.readLong()
+ override def readChar(): Char = input.readChar()
+ override def readFloat(): Float = input.readFloat()
+ override def readByte(): Byte = input.readByte()
+ override def readShort(): Short = input.readShort()
+ override def readUTF(): String = input.readString() // readString in kryo does utf8
+ override def readInt(): Int = input.readInt()
+ override def readUnsignedShort(): Int = input.readShortUnsigned()
+ override def skipBytes(n: Int): Int = input.skip(n.toLong).toInt
+ override def readFully(b: Array[Byte]): Unit = input.read(b)
+ override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len)
+ override def readLine(): String = throw new UnsupportedOperationException("readLine")
+ override def readBoolean(): Boolean = input.readBoolean()
+ override def readUnsignedByte(): Int = input.readByteUnsigned()
+ override def readDouble(): Double = input.readDouble()
+}
+
+private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput {
+ override def writeFloat(v: Float): Unit = output.writeFloat(v)
+ // There is no "readChars" counterpart, except maybe "readLine", which is not supported
+ override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars")
+ override def writeDouble(v: Double): Unit = output.writeDouble(v)
+ override def writeUTF(s: String): Unit = output.writeString(s) // writeString in kryo does UTF8
+ override def writeShort(v: Int): Unit = output.writeShort(v)
+ override def writeInt(v: Int): Unit = output.writeInt(v)
+ override def writeBoolean(v: Boolean): Unit = output.writeBoolean(v)
+ override def write(b: Int): Unit = output.write(b)
+ override def write(b: Array[Byte]): Unit = output.write(b)
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len)
+ override def writeBytes(s: String): Unit = output.writeString(s)
+ override def writeChar(v: Int): Unit = output.writeChar(v.toChar)
+ override def writeLong(v: Long): Unit = output.writeLong(v)
+ override def writeByte(v: Int): Unit = output.writeByte(v)
}
/**
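
The two bridge classes above exist because RoaringBitmap serializes itself through java.io.DataOutput/DataInput rather than Kryo's own streams. A standalone round-trip sketch; it assumes code inside org.apache.spark.serializer, since the bridges are private[serializer].

    import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
    import org.roaringbitmap.RoaringBitmap

    val bitmap = new RoaringBitmap()
    bitmap.add(7)
    bitmap.add(42)

    // Write through the DataOutput bridge into a growable Kryo buffer.
    val output = new KryoOutput(1024, -1)
    bitmap.serialize(new KryoOutputDataOutputBridge(output))

    // Read it back through the DataInput bridge.
    val input = new KryoInput(output.toBytes)
    val restored = new RoaringBitmap()
    restored.deserialize(new KryoInputDataInputBridge(input))
    assert(restored.contains(42) && !restored.contains(8))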
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 39fadd878351..cc5f933393ad 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -46,7 +46,7 @@ private[spark] trait ShuffleWriterGroup {
private[spark] class FileShuffleBlockResolver(conf: SparkConf)
extends ShuffleBlockResolver with Logging {
- private val transportConf = SparkTransportConf.fromSparkConf(conf)
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
private lazy val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 05b1eed7f3be..fadb8fe7ed0a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -47,7 +47,7 @@ private[spark] class IndexShuffleBlockResolver(
private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
- private val transportConf = SparkTransportConf.fromSparkConf(conf)
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
def getDataFile(shuffleId: Int, mapId: Int): File = {
blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 661c706af32b..ab0007fb7899 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -122,7 +122,7 @@ private[spark] class BlockManager(
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
- val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
+ val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
securityManager.isSaslEncryptionEnabled())
} else {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index be144f6065ba..1268f44596f8 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ui.jobs
import scala.collection.mutable
-import scala.xml.Node
+import scala.xml.{Unparsed, Node}
import org.apache.spark.ui.{ToolTips, UIUtils}
import org.apache.spark.ui.jobs.UIData.StageUIData
@@ -52,7 +52,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
+
}
private def createExecutorTable() : Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index ea806d09b600..2a1c3c1a50ec 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -145,9 +145,22 @@ private[ui] class StageTableBase(
case None => "Unknown"
}
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis)
- val duration = s.submissionTime.map { t =>
- if (finishTime > t) finishTime - t else System.currentTimeMillis - t
- }
+
+ // The submission time for a stage is misleading because it counts the time
+ // the stage waits to be launched. (SPARK-10930)
+ val taskLaunchTimes =
+ stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0)
+ val duration: Option[Long] =
+ if (taskLaunchTimes.nonEmpty) {
+ val startTime = taskLaunchTimes.min
+ if (finishTime > startTime) {
+ Some(finishTime - startTime)
+ } else {
+ Some(System.currentTimeMillis() - startTime)
+ }
+ } else {
+ None
+ }
val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
val inputRead = stageData.inputBytes
diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
index b3b54af972cb..6c1fca71f228 100644
--- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
+import scala.util.DynamicVariable
import org.apache.spark.SparkContext
@@ -60,22 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri
private val listenerThread = new Thread(name) {
setDaemon(true)
override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
- while (true) {
- eventLock.acquire()
- self.synchronized {
- processingEvent = true
- }
- try {
- if (stopped.get()) {
- // Get out of the while loop and shutdown the daemon thread
- return
- }
- val event = eventQueue.poll
- assert(event != null, "event queue was empty but the listener bus was not stopped")
- postToAll(event)
- } finally {
+ AsynchronousListenerBus.withinListenerThread.withValue(true) {
+ while (true) {
+ eventLock.acquire()
self.synchronized {
- processingEvent = false
+ processingEvent = true
+ }
+ try {
+ val event = eventQueue.poll
+ if (event == null) {
+ // Get out of the while loop and shutdown the daemon thread
+ if (!stopped.get) {
+ throw new IllegalStateException("Polling `null` from eventQueue means" +
+ " the listener bus has been stopped. So `stopped` must be true")
+ }
+ return
+ }
+ postToAll(event)
+ } finally {
+ self.synchronized {
+ processingEvent = false
+ }
}
}
}
@@ -174,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri
*/
def onDropEvent(event: E): Unit
}
+
+private[spark] object AsynchronousListenerBus {
+ /* Allows SparkContext to check whether a stop() call is made from within the
+ * listener thread. */
+ val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
+}
+
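
A minimal sketch of the DynamicVariable pattern introduced above: the flag is true only for code executed inside withValue, so a stop() issued from the listener thread itself can be detected. The stop() body here is illustrative; per the comment, the real check lives in SparkContext.

    import scala.util.DynamicVariable

    val withinListenerThread = new DynamicVariable[Boolean](false)

    def stop(): Unit = {
      if (withinListenerThread.value) {
        throw new IllegalStateException("Cannot stop the bus from within the listener thread")
      }
      // ... normal shutdown path ...
    }

    // Inside the listener thread's run() loop:
    withinListenerThread.withValue(true) {
      // Any stop() called while an event is being processed observes value == true.
    }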
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 23ee4eff0881..09864e3f8392 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -31,6 +31,21 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.OpenHashSet
+/**
+ * A trait that allows a class to give [[SizeEstimator]] more accurate size estimation.
+ * When a class extends it, [[SizeEstimator]] will query `estimatedSize` and use the returned
+ * value as the size of the object, skipping its own reflection-based estimation.
+ * The difference between [[KnownSizeEstimation]] and
+ * [[org.apache.spark.util.collection.SizeTracker]] is that a
+ * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to estimate
+ * the size, whereas a [[KnownSizeEstimation]] can provide a better estimate without using
+ * [[SizeEstimator]] at all.
+ */
+private[spark] trait KnownSizeEstimation {
+ def estimatedSize: Long
+}
+
/**
* :: DeveloperApi ::
* Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
@@ -199,10 +214,15 @@ object SizeEstimator extends Logging {
// the size estimator since it references the whole REPL. Do nothing in this case. In
// general all ClassLoaders and Classes will be shared between objects anyway.
} else {
- val classInfo = getClassInfo(cls)
- state.size += alignSize(classInfo.shellSize)
- for (field <- classInfo.pointerFields) {
- state.enqueue(field.get(obj))
+ obj match {
+ case s: KnownSizeEstimation =>
+ state.size += s.estimatedSize
+ case _ =>
+ val classInfo = getClassInfo(cls)
+ state.size += alignSize(classInfo.shellSize)
+ for (field <- classInfo.pointerFields) {
+ state.enqueue(field.get(obj))
+ }
}
}
}
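
A minimal sketch of a class opting into the trait above. The wrapper class is hypothetical, and KnownSizeEstimation is private[spark], so this assumes Spark-internal code; SizeEstimator.estimate then reports roughly the provided size instead of reflecting over the object's fields.

    import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}

    // Hypothetical wrapper that already tracks how many bytes its payload occupies,
    // e.g. a serialized or off-heap buffer.
    class PreSizedBuffer(payloadSize: Long) extends KnownSizeEstimation {
      override def estimatedSize: Long = payloadSize
    }

    // Returns roughly the provided value rather than a reflection-based estimate.
    val size = SizeEstimator.estimate(new PreSizedBuffer(1L << 20))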
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
index 724818724733..5e322557e964 100644
--- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
- logError("Uncaught exception in thread " + thread, exception)
+ // Make it explicit when an uncaught exception is thrown while the container is shutting
+ // down, since that helps users analyze the executor logs.
+ val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else ""
+ val errMsg = "Uncaught exception in thread "
+ logError(inShutdownMsg + errMsg + thread, exception)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index 85c5bdbfcebc..7ab67fc3a2de 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -17,21 +17,14 @@
package org.apache.spark.util.collection
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
-
-import org.apache.spark.util.{Utils => UUtils}
-
-
/**
* A simple, fixed-size bit set implementation. This implementation is fast because it avoids
* safety/bound checking.
*/
-class BitSet(private[this] var numBits: Int) extends Externalizable {
+class BitSet(numBits: Int) extends Serializable {
- private var words = new Array[Long](bit2words(numBits))
- private def numWords = words.length
-
- def this() = this(0)
+ private val words = new Array[Long](bit2words(numBits))
+ private val numWords = words.length
/**
* Compute the capacity (number of bits) that can be represented
@@ -237,19 +230,4 @@ class BitSet(private[this] var numBits: Int) extends Externalizable {
/** Return the number of longs it would take to hold numBits. */
private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
-
- override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException {
- out.writeInt(numBits)
- words.foreach(out.writeLong(_))
- }
-
- override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException {
- numBits = in.readInt()
- words = new Array[Long](bit2words(numBits))
- var index = 0
- while (index < words.length) {
- words(index) = in.readLong()
- index += 1
- }
- }
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 119e5fc28e41..ab23326c6c25 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,17 +21,223 @@ import java.io.File
import scala.reflect.ClassTag
+import org.apache.spark.CheckpointSuite._
import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
+trait RDDCheckpointTester { self: SparkFunSuite =>
+
+ protected val partitioner = new HashPartitioner(2)
+
+ private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
+
+ /** Implementations must provide the SparkContext used to run the checkpointing tests. */
+ protected def sparkContext: SparkContext
+
+ /**
+ * Test checkpointing of the RDD generated by the given operation. It checks whether the
+ * serialized size of the RDD is reduced after checkpointing. This function should be called
+ * on all RDDs that have a parent RDD (i.e., do not call it on ParallelCollection, BlockRDD, etc.).
+ *
+ * @param op an operation to run on the RDD
+ * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
+ * @param collectFunc a function for collecting the values in the RDD, in case there are
+ * non-comparable types like arrays that we want to convert to something
+ * that supports ==
+ */
+ protected def testRDD[U: ClassTag](
+ op: (RDD[Int]) => RDD[U],
+ reliableCheckpoint: Boolean,
+ collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateFatRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numPartitions = operatedRDD.partitions.length
+
+ // Force initialization of all the data structures in RDDs
+ // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+ initializeRdd(operatedRDD)
+
+ val partitionsBeforeCheckpoint = operatedRDD.partitions
+
+ // Find serialized sizes before and after the checkpoint
+ logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+ val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ checkpoint(operatedRDD, reliableCheckpoint)
+ val result = collectFunc(operatedRDD)
+ operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+ val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+
+ // Test whether the checkpoint file has been created
+ if (reliableCheckpoint) {
+ assert(
+ collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+ }
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the partitions have been changed from its earlier partitions
+ assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)
+
+ // Test whether the partitions have been changed to the new Hadoop partitions
+ assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
+
+ // Test whether the number of partitions is the same as before
+ assert(operatedRDD.partitions.length === numPartitions)
+
+ // Test whether the data in the checkpointed RDD is the same as the original
+ assert(collectFunc(operatedRDD) === result)
+
+ // Test whether the serialized size of the RDD has been reduced.
+ logInfo("Size of " + rddType +
+ " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ /**
+ * Test whether checkpointing of the parent of the generated RDD also
+ * truncates the lineage or not. Some RDDs, like CoGroupedRDD, hold on to their parent
+ * RDDs' partitions. So even if the parent RDD is checkpointed and its partitions changed,
+ * the generated RDD will remember the partitions and therefore potentially the whole lineage.
+ * This function should be called only on RDDs whose partitions refer to the parent RDD's
+ * partitions (i.e., do not call it on a simple RDD like MappedRDD).
+ *
+ * @param op an operation to run on the RDD
+ * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
+ * @param collectFunc a function for collecting the values in the RDD, in case there are
+ * non-comparable types like arrays that we want to convert to something
+ * that supports ==
+ */
+ protected def testRDDPartitions[U: ClassTag](
+ op: (RDD[Int]) => RDD[U],
+ reliableCheckpoint: Boolean,
+ collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateFatRDD()
+ val operatedRDD = op(baseRDD)
+ val parentRDDs = operatedRDD.dependencies.map(_.rdd)
+ val rddType = operatedRDD.getClass.getSimpleName
+
+ // Force initialization of all the data structures in RDDs
+ // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+ initializeRdd(operatedRDD)
+
+ // Find serialized sizes before and after the checkpoint
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+ val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ // Checkpoint the parent RDDs, not the generated one
+ parentRDDs.foreach { rdd =>
+ checkpoint(rdd, reliableCheckpoint)
+ }
+ val result = collectFunc(operatedRDD) // force checkpointing
+ operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+ val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+
+ // Test whether the data in the checkpointed RDD is the same as the original
+ assert(collectFunc(operatedRDD) === result)
+
+ // Test whether the serialized size of the partitions has been reduced
+ logInfo("Size of partitions of " + rddType +
+ " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
+ assert(
+ partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
+ "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
+ " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
+ )
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
+ */
+ private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ val rddSize = Utils.serialize(rdd).size
+ val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
+ val rddPartitionSize = Utils.serialize(rdd.partitions).size
+ val rddDependenciesSize = Utils.serialize(rdd.dependencies).size
+
+ // Print detailed size, helps in debugging
+ logInfo("Serialized sizes of " + rdd +
+ ": RDD = " + rddSize +
+ ", RDD checkpoint data = " + rddCpDataSize +
+ ", RDD partitions = " + rddPartitionSize +
+ ", RDD dependencies = " + rddDependenciesSize
+ )
+ // this makes sure that serializing the RDD's checkpoint data does not
+ // serialize the whole RDD as well
+ assert(
+ rddSize > rddCpDataSize,
+ "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " +
+ "whole RDD with checkpoint data (" + rddSize + ")"
+ )
+ (rddSize - rddCpDataSize, rddPartitionSize)
+ }
+
+ /**
+ * Serialize and deserialize an object. This is useful to verify the objects
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task)
+ */
+ protected def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
+ }
+
+ /**
+ * Recursively force the initialization of all the members of an RDD and its parents.
+ */
+ private def initializeRdd(rdd: RDD[_]): Unit = {
+ rdd.partitions // forces the initialization of the partitions
+ rdd.dependencies.map(_.rdd).foreach(initializeRdd)
+ }
+
+ /** Checkpoint the RDD either locally or reliably. */
+ protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
+ if (reliableCheckpoint) {
+ rdd.checkpoint()
+ } else {
+ rdd.localCheckpoint()
+ }
+ }
+
+ /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
+ protected def runTest(name: String)(body: Boolean => Unit): Unit = {
+ test(name + " [reliable checkpoint]")(body(true))
+ test(name + " [local checkpoint]")(body(false))
+ }
+
+ /**
+ * Generate an RDD such that both the RDD and its partitions have large size.
+ */
+ protected def generateFatRDD(): RDD[Int] = {
+ new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x)
+ }
+
+ /**
+ * Generate a pair RDD (with a partitioner) such that both the RDD and its partitions
+ * have large size.
+ */
+ protected def generateFatPairRDD(): RDD[(Int, Int)] = {
+ new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
+ }
+}
+
/**
* Test suite for end-to-end checkpointing functionality.
* This tests both reliable checkpoints and local checkpoints.
*/
-class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging {
+class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext {
private var checkpointDir: File = _
- private val partitioner = new HashPartitioner(2)
override def beforeEach(): Unit = {
super.beforeEach()
@@ -46,6 +252,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
Utils.deleteRecursively(checkpointDir)
}
+ override def sparkContext: SparkContext = sc
+
runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
@@ -250,204 +458,6 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
assert(rdd.isCheckpointedAndMaterialized === true)
assert(rdd.partitions.size === 0)
}
-
- // Utility test methods
-
- /** Checkpoint the RDD either locally or reliably. */
- private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
- if (reliableCheckpoint) {
- rdd.checkpoint()
- } else {
- rdd.localCheckpoint()
- }
- }
-
- /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
- private def runTest(name: String)(body: Boolean => Unit): Unit = {
- test(name + " [reliable checkpoint]")(body(true))
- test(name + " [local checkpoint]")(body(false))
- }
-
- private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
-
- /**
- * Test checkpointing of the RDD generated by the given operation. It tests whether the
- * serialized size of the RDD is reduce after checkpointing or not. This function should be called
- * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
- *
- * @param op an operation to run on the RDD
- * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
- * @param collectFunc a function for collecting the values in the RDD, in case there are
- * non-comparable types like arrays that we want to convert to something that supports ==
- */
- private def testRDD[U: ClassTag](
- op: (RDD[Int]) => RDD[U],
- reliableCheckpoint: Boolean,
- collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
- // Generate the final RDD using given RDD operation
- val baseRDD = generateFatRDD()
- val operatedRDD = op(baseRDD)
- val parentRDD = operatedRDD.dependencies.headOption.orNull
- val rddType = operatedRDD.getClass.getSimpleName
- val numPartitions = operatedRDD.partitions.length
-
- // Force initialization of all the data structures in RDDs
- // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
- initializeRdd(operatedRDD)
-
- val partitionsBeforeCheckpoint = operatedRDD.partitions
-
- // Find serialized sizes before and after the checkpoint
- logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
- val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- checkpoint(operatedRDD, reliableCheckpoint)
- val result = collectFunc(operatedRDD)
- operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
- val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
- logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
-
- // Test whether the checkpoint file has been created
- if (reliableCheckpoint) {
- assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
- }
-
- // Test whether dependencies have been changed from its earlier parent RDD
- assert(operatedRDD.dependencies.head.rdd != parentRDD)
-
- // Test whether the partitions have been changed from its earlier partitions
- assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)
-
- // Test whether the partitions have been changed to the new Hadoop partitions
- assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
-
- // Test whether the number of partitions is same as before
- assert(operatedRDD.partitions.length === numPartitions)
-
- // Test whether the data in the checkpointed RDD is same as original
- assert(collectFunc(operatedRDD) === result)
-
- // Test whether serialized size of the RDD has reduced.
- logInfo("Size of " + rddType +
- " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
- assert(
- rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
- "Size of " + rddType + " did not reduce after checkpointing " +
- " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
- )
- }
-
- /**
- * Test whether checkpointing of the parent of the generated RDD also
- * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
- * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
- * the generated RDD will remember the partitions and therefore potentially the whole lineage.
- * This function should be called only those RDD whose partitions refer to parent RDD's
- * partitions (i.e., do not call it on simple RDD like MappedRDD).
- *
- * @param op an operation to run on the RDD
- * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
- * @param collectFunc a function for collecting the values in the RDD, in case there are
- * non-comparable types like arrays that we want to convert to something that supports ==
- */
- private def testRDDPartitions[U: ClassTag](
- op: (RDD[Int]) => RDD[U],
- reliableCheckpoint: Boolean,
- collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
- // Generate the final RDD using given RDD operation
- val baseRDD = generateFatRDD()
- val operatedRDD = op(baseRDD)
- val parentRDDs = operatedRDD.dependencies.map(_.rdd)
- val rddType = operatedRDD.getClass.getSimpleName
-
- // Force initialization of all the data structures in RDDs
- // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
- initializeRdd(operatedRDD)
-
- // Find serialized sizes before and after the checkpoint
- logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
- val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- // checkpoint the parent RDD, not the generated one
- parentRDDs.foreach { rdd =>
- checkpoint(rdd, reliableCheckpoint)
- }
- val result = collectFunc(operatedRDD) // force checkpointing
- operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
- val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
- logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
-
- // Test whether the data in the checkpointed RDD is same as original
- assert(collectFunc(operatedRDD) === result)
-
- // Test whether serialized size of the partitions has reduced
- logInfo("Size of partitions of " + rddType +
- " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
- assert(
- partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
- "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
- " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
- )
- }
-
- /**
- * Generate an RDD such that both the RDD and its partitions have large size.
- */
- private def generateFatRDD(): RDD[Int] = {
- new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x)
- }
-
- /**
- * Generate an pair RDD (with partitioner) such that both the RDD and its partitions
- * have large size.
- */
- private def generateFatPairRDD(): RDD[(Int, Int)] = {
- new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
- }
-
- /**
- * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
- * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
- */
- private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
- val rddSize = Utils.serialize(rdd).size
- val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
- val rddPartitionSize = Utils.serialize(rdd.partitions).size
- val rddDependenciesSize = Utils.serialize(rdd.dependencies).size
-
- // Print detailed size, helps in debugging
- logInfo("Serialized sizes of " + rdd +
- ": RDD = " + rddSize +
- ", RDD checkpoint data = " + rddCpDataSize +
- ", RDD partitions = " + rddPartitionSize +
- ", RDD dependencies = " + rddDependenciesSize
- )
- // this makes sure that serializing the RDD's checkpoint data does not
- // serialize the whole RDD as well
- assert(
- rddSize > rddCpDataSize,
- "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " +
- "whole RDD with checkpoint data (" + rddSize + ")"
- )
- (rddSize - rddCpDataSize, rddPartitionSize)
- }
-
- /**
- * Serialize and deserialize an object. This is useful to verify the objects
- * contents after deserialization (e.g., the contents of an RDD split after
- * it is sent to a slave along with a task)
- */
- private def serializeDeserialize[T](obj: T): T = {
- val bytes = Utils.serialize(obj)
- Utils.deserialize[T](bytes)
- }
-
- /**
- * Recursively force the initialization of the all members of an RDD and it parents.
- */
- private def initializeRdd(rdd: RDD[_]): Unit = {
- rdd.partitions // forces the
- rdd.dependencies.map(_.rdd).foreach(initializeRdd)
- }
-
}
/** RDD partition that has large serialized size. */
@@ -494,5 +504,4 @@ object CheckpointSuite {
part
).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
}
-
}
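
For context on how the RDDCheckpointTester trait extracted above is meant to be reused, here is a hedged sketch of a hypothetical suite mixing it in. The suite name and the mapped operation are invented, and it assumes the suite lives in the org.apache.spark package with a SparkContext created in beforeEach, just as CheckpointSuite does.

```scala
// Illustrative only; not part of this patch.
class MyCheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext {

  // sc is assumed to be created in beforeEach, as CheckpointSuite does above.
  override def sparkContext: SparkContext = sc

  // runTest registers two ScalaTest cases: one with reliable and one with local checkpointing.
  runTest("map keeps its data across a checkpoint") { reliableCheckpoint: Boolean =>
    // testRDD checkpoints the operated RDD and asserts that the lineage is truncated
    // and that its serialized size shrinks.
    testRDD(rdd => rdd.map(_ + 1), reliableCheckpoint)
  }
}
```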
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index 231f4631e0a4..1c775bcb3d9c 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -35,7 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
var rpcHandler: ExternalShuffleBlockHandler = _
override def beforeAll() {
- val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2)
+ val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2)
rpcHandler = new ExternalShuffleBlockHandler(transportConf, null)
val transportContext = new TransportContext(transportConf, rpcHandler)
server = transportContext.createServer()
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 66a50512003d..d494b0caab85 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.api.r.RUtils
import org.apache.spark.deploy.SparkSubmit._
import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate
import org.apache.spark.util.{ResetSystemProperties, Utils}
@@ -368,10 +369,9 @@ class SparkSubmitSuite
}
}
- test("correctly builds R packages included in a jar with --packages") {
- // TODO(SPARK-9603): Building a package to $SPARK_HOME/R/lib is unavailable on Jenkins.
- // It's hard to write the test in SparkR (because we can't create the repository dynamically)
- /*
+ // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds.
+ // See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log
+ ignore("correctly builds R packages included in a jar with --packages") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -389,7 +389,6 @@ class SparkSubmitSuite
rScriptDir)
runSparkSubmit(args)
}
- */
}
test("resolves command line argument paths correctly") {
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index d145e78834b1..2fa795f84666 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.deploy
+import scala.collection.mutable
import scala.concurrent.duration._
import org.mockito.Mockito.{mock, when}
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
@@ -29,6 +30,7 @@ import org.apache.spark.deploy.master.ApplicationInfo
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
+import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor
@@ -38,7 +40,8 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterE
class StandaloneDynamicAllocationSuite
extends SparkFunSuite
with LocalSparkContext
- with BeforeAndAfterAll {
+ with BeforeAndAfterAll
+ with PrivateMethodTester {
private val numWorkers = 2
private val conf = new SparkConf()
@@ -404,6 +407,41 @@ class StandaloneDynamicAllocationSuite
assert(apps.head.getExecutorLimit === 1)
}
+ test("disable force kill for busy executors (SPARK-9552)") {
+ sc = new SparkContext(appConf)
+ val appId = sc.applicationId
+ eventually(timeout(10.seconds), interval(10.millis)) {
+ val apps = getApplications()
+ assert(apps.size === 1)
+ assert(apps.head.id === appId)
+ assert(apps.head.executors.size === 2)
+ assert(apps.head.getExecutorLimit === Int.MaxValue)
+ }
+ var apps = getApplications()
+ // sync executors between the Master and the driver, needed because
+ // the driver refuses to kill executors it does not know about
+ syncExecutors(sc)
+ val executors = getExecutorIds(sc)
+ assert(executors.size === 2)
+
+ // simulate running a task on the executor
+ val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount)
+ val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+ val executorIdToTaskCount = taskScheduler invokePrivate getMap()
+ executorIdToTaskCount(executors.head) = 1
+ // try to kill the busy executor without force; the executor should not actually be killed
+ assert(killExecutor(sc, executors.head, force = false))
+ apps = getApplications()
+ assert(apps.head.executors.size === 2)
+
+ // force kill busy executor
+ assert(killExecutor(sc, executors.head, force = true))
+ apps = getApplications()
+ // kill executor successfully
+ assert(apps.head.executors.size === 1)
+
+ }
+
// ===============================
// | Utility methods for testing |
// ===============================
@@ -455,6 +493,16 @@ class StandaloneDynamicAllocationSuite
sc.killExecutors(getExecutorIds(sc).take(n))
}
+ /** Kill the given executor, specifying whether to force kill it. */
+ private def killExecutor(sc: SparkContext, executorId: String, force: Boolean): Boolean = {
+ syncExecutors(sc)
+ sc.schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.killExecutors(Seq(executorId), replace = false, force)
+ case _ => fail("expected coarse grained scheduler")
+ }
+ }
+
/**
* Return a list of executor IDs belonging to this application.
*
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
index 34775577de8a..7a4472867568 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
@@ -63,56 +63,60 @@ class PersistenceEngineSuite extends SparkFunSuite {
conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = {
val serializer = new JavaSerializer(conf)
val persistenceEngine = persistenceEngineCreator(serializer)
- persistenceEngine.persist("test_1", "test_1_value")
- assert(Seq("test_1_value") === persistenceEngine.read[String]("test_"))
- persistenceEngine.persist("test_2", "test_2_value")
- assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet)
- persistenceEngine.unpersist("test_1")
- assert(Seq("test_2_value") === persistenceEngine.read[String]("test_"))
- persistenceEngine.unpersist("test_2")
- assert(persistenceEngine.read[String]("test_").isEmpty)
-
- // Test deserializing objects that contain RpcEndpointRef
- val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
try {
- // Create a real endpoint so that we can test RpcEndpointRef deserialization
- val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint {
- override val rpcEnv: RpcEnv = testRpcEnv
- })
-
- val workerToPersist = new WorkerInfo(
- id = "test_worker",
- host = "127.0.0.1",
- port = 10000,
- cores = 0,
- memory = 0,
- endpoint = workerEndpoint,
- webUiPort = 0,
- publicAddress = ""
- )
-
- persistenceEngine.addWorker(workerToPersist)
-
- val (storedApps, storedDrivers, storedWorkers) =
- persistenceEngine.readPersistedData(testRpcEnv)
-
- assert(storedApps.isEmpty)
- assert(storedDrivers.isEmpty)
-
- // Check deserializing WorkerInfo
- assert(storedWorkers.size == 1)
- val recoveryWorkerInfo = storedWorkers.head
- assert(workerToPersist.id === recoveryWorkerInfo.id)
- assert(workerToPersist.host === recoveryWorkerInfo.host)
- assert(workerToPersist.port === recoveryWorkerInfo.port)
- assert(workerToPersist.cores === recoveryWorkerInfo.cores)
- assert(workerToPersist.memory === recoveryWorkerInfo.memory)
- assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint)
- assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort)
- assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress)
+ persistenceEngine.persist("test_1", "test_1_value")
+ assert(Seq("test_1_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.persist("test_2", "test_2_value")
+ assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet)
+ persistenceEngine.unpersist("test_1")
+ assert(Seq("test_2_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.unpersist("test_2")
+ assert(persistenceEngine.read[String]("test_").isEmpty)
+
+ // Test deserializing objects that contain RpcEndpointRef
+ val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+ try {
+ // Create a real endpoint so that we can test RpcEndpointRef deserialization
+ val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint {
+ override val rpcEnv: RpcEnv = testRpcEnv
+ })
+
+ val workerToPersist = new WorkerInfo(
+ id = "test_worker",
+ host = "127.0.0.1",
+ port = 10000,
+ cores = 0,
+ memory = 0,
+ endpoint = workerEndpoint,
+ webUiPort = 0,
+ publicAddress = ""
+ )
+
+ persistenceEngine.addWorker(workerToPersist)
+
+ val (storedApps, storedDrivers, storedWorkers) =
+ persistenceEngine.readPersistedData(testRpcEnv)
+
+ assert(storedApps.isEmpty)
+ assert(storedDrivers.isEmpty)
+
+ // Check deserializing WorkerInfo
+ assert(storedWorkers.size == 1)
+ val recoveryWorkerInfo = storedWorkers.head
+ assert(workerToPersist.id === recoveryWorkerInfo.id)
+ assert(workerToPersist.host === recoveryWorkerInfo.host)
+ assert(workerToPersist.port === recoveryWorkerInfo.port)
+ assert(workerToPersist.cores === recoveryWorkerInfo.cores)
+ assert(workerToPersist.memory === recoveryWorkerInfo.memory)
+ assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint)
+ assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort)
+ assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress)
+ } finally {
+ testRpcEnv.shutdown()
+ testRpcEnv.awaitTermination()
+ }
} finally {
- testRpcEnv.shutdown()
- testRpcEnv.awaitTermination()
+ persistenceEngine.close()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 834e4743df86..2f55006420ce 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -39,7 +39,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll(): Unit = {
val conf = new SparkConf()
- env = createRpcEnv(conf, "local", 12345)
+ env = createRpcEnv(conf, "local", 0)
}
override def afterAll(): Unit = {
@@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely")
try {
@@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely")
try {
@@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
val shortProp = "spark.rpc.short.timeout"
conf.set("spark.rpc.retry.wait", "0")
conf.set("spark.rpc.numRetries", "1")
- val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
try {
@@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely")
try {
@@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "sendWithReply-remotely-error")
@@ -497,7 +497,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "network-events")
@@ -543,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
})
- val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
+ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true)
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(
"local", env.address, "sendWithReply-unserializable-error")
@@ -571,8 +571,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
- val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
+ val localEnv = createRpcEnv(conf, "authentication-local", 0)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true)
try {
@volatile var message: String = null
@@ -602,8 +602,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
- val localEnv = createRpcEnv(conf, "authentication-local", 13345)
- val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
+ val localEnv = createRpcEnv(conf, "authentication-local", 0)
+ val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true)
try {
localEnv.setupEndpoint("ask-authentication", new RpcEndpoint {
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
index 6478ab51c4da..7aac02775e1b 100644
--- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -40,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
})
val conf = new SparkConf()
val newRpcEnv = new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false))
+ RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false))
try {
val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint")
assert(s"akka.tcp://local@${env.address}/user/test_endpoint" ===
@@ -59,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
val conf = SSLSampleConfigs.sparkSSLConfig()
val securityManager = new SecurityManager(conf)
val rpcEnv = new AkkaRpcEnvFactory().create(
- RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false))
+ RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false))
try {
val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index b8e466fab450..15c8de61b824 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
+import org.roaringbitmap.RoaringBitmap
import scala.util.Random
@@ -97,4 +98,34 @@ class MapStatusSuite extends SparkFunSuite {
val buf = ser.newInstance().serialize(status)
ser.newInstance().deserialize[MapStatus](buf)
}
+
+ test("RoaringBitmap: runOptimize succeeded") {
+ val r = new RoaringBitmap
+ (1 to 200000).foreach(i =>
+ if (i % 200 != 0) {
+ r.add(i)
+ }
+ )
+ val size1 = r.getSizeInBytes
+ val success = r.runOptimize()
+ r.trim()
+ val size2 = r.getSizeInBytes
+ assert(size1 > size2)
+ assert(success)
+ }
+
+ test("RoaringBitmap: runOptimize failed") {
+ val r = new RoaringBitmap
+ (1 to 200000).foreach(i =>
+ if (i % 200 == 0) {
+ r.add(i)
+ }
+ )
+ val size1 = r.getSizeInBytes
+ val success = r.runOptimize()
+ r.trim()
+ val size2 = r.getSizeInBytes
+ assert(size1 === size2)
+ assert(!success)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 53102b9f1c93..84e545851f49 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -269,14 +269,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
}
test("onTaskGettingResult() called when result fetched remotely") {
- sc = new SparkContext("local", "SparkListenerSuite")
+ val conf = new SparkConf().set("spark.akka.frameSize", "1")
+ sc = new SparkContext("local", "SparkListenerSuite", conf)
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
// Make a task whose result is larger than the akka frame size
- System.setProperty("spark.akka.frameSize", "1")
val akkaFrameSize =
sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
+ assert(akkaFrameSize === 1024 * 1024)
val result = sc.parallelize(Seq(1), 1)
.map { x => 1.to(akkaFrameSize).toArray }
.reduce { case (x, y) => x }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 815caa79ff52..bc72c3685e8c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler
+import java.io.File
+import java.net.URL
import java.nio.ByteBuffer
import scala.concurrent.duration._
@@ -26,8 +28,10 @@ import scala.util.control.NonFatal
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.storage.TaskResultBlockId
+import org.apache.spark.TestUtils.JavaSourceFromString
+import org.apache.spark.util.{MutableURLClassLoader, Utils}
/**
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
@@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
// Make sure two tasks were run (one failed one, and a second retried one).
assert(scheduler.nextTaskId.get() === 2)
}
+
+ /**
+ * Make sure we are using the context classloader when deserializing failed TaskResults instead
+ * of the Spark classloader.
+ *
+ * This test compiles a jar containing an exception and tests that when it is thrown on the
+ * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
+ * exception as the cause.
+ *
+ * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
+ * the exception, resulting in an UnknownReason for the TaskEndResult.
+ */
+ test("failed task deserialized with the correct classloader (SPARK-11195)") {
+ // compile a small jar containing an exception that will be thrown on an executor.
+ val tempDir = Utils.createTempDir()
+ val srcDir = new File(tempDir, "repro/")
+ srcDir.mkdirs()
+ val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
+ """package repro;
+ |
+ |public class MyException extends Exception {
+ |}
+ """.stripMargin)
+ val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
+ val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
+ TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))
+
+ // ensure we reset the classloader after the test completes
+ val originalClassLoader = Thread.currentThread.getContextClassLoader
+ try {
+ // load the exception from the jar
+ val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
+ loader.addURL(jarFile.toURI.toURL)
+ Thread.currentThread().setContextClassLoader(loader)
+ val excClass: Class[_] = Utils.classForName("repro.MyException")
+
+ // NOTE: we must run the cluster with "local" so that the executor can load the compiled
+ // jar.
+ sc = new SparkContext("local", "test", conf)
+ val rdd = sc.parallelize(Seq(1), 1).map { _ =>
+ val exc = excClass.newInstance().asInstanceOf[Exception]
+ throw exc
+ }
+
+ // the driver should not have any problems resolving the exception class and determining
+ // why the task failed.
+ val exceptionMessage = intercept[SparkException] {
+ rdd.collect()
+ }.getMessage
+
+ val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
+ val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r
+
+ assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
+ assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
+ } finally {
+ Thread.currentThread.setContextClassLoader(originalClassLoader)
+ }
+ }
}
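
The save/replace/restore handling of the context classloader in the test above is the general pattern for loading user classes that are not on Spark's own classpath. A stripped-down sketch follows; it uses the plain JDK URLClassLoader rather than Spark's MutableURLClassLoader so it is self-contained, and the jar path is a placeholder.

```scala
import java.net.{URL, URLClassLoader}

val original = Thread.currentThread.getContextClassLoader
try {
  // Build a classloader over the user jar (placeholder path), delegating to the original loader.
  val jarUrl = new java.io.File("/path/to/user.jar").toURI.toURL
  val loader = new URLClassLoader(Array[URL](jarUrl), original)
  Thread.currentThread.setContextClassLoader(loader)
  // ... deserialize results or reflectively instantiate classes from the jar here ...
} finally {
  Thread.currentThread.setContextClassLoader(original)
}
```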
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index afe2e80358ca..e428414cf6e8 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -322,6 +322,12 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
val conf = new SparkConf(false)
conf.set("spark.kryo.registrationRequired", "true")
+ // These cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16
+ // values, and they use a bitmap (dense) if they have more than 4096 values, and an
+ // array (sparse) if they have fewer. So we just create two cases, one sparse and one dense.
+ // We use a roaring bitmap for the empty blocks, so we trigger the dense case with mostly
+ // empty blocks.
+
val ser = new KryoSerializer(conf).newInstance()
val denseBlockSizes = new Array[Long](5000)
val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L)
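
To make the dense/sparse container distinction described in the comment above concrete, here is a small, illustrative RoaringBitmap sketch (not part of the patch). It uses only the calls already exercised in MapStatusSuite, and the exact sizes printed depend on the library version.

```scala
import org.roaringbitmap.RoaringBitmap

// Sparse container: fewer than 4096 values in the first 2^16 block are kept in an array.
val sparse = new RoaringBitmap
(0 until 1000).foreach(i => sparse.add(i))

// Dense container: more than 4096 values in one block switch the container to a bitmap.
val dense = new RoaringBitmap
(0 until 60000).foreach(i => dense.add(i))

println(s"sparse: ${sparse.getSizeInBytes} bytes, dense: ${dense.getSizeInBytes} bytes")
```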
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index 61601016e005..0af4b6098bb0 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -340,10 +340,11 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
val slaveConf = sparkSSLConfig()
+ .set("spark.rpc.askTimeout", "5s")
+ .set("spark.rpc.lookupTimeout", "5s")
val securityManagerBad = new SecurityManager(slaveConf)
val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad)
- val slaveTracker = new MapOutputTrackerWorker(conf)
try {
slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
fail("should receive either ActorNotFound or TimeoutException")
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 20550178fb1b..101610e38014 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -60,6 +60,12 @@ class DummyString(val arr: Array[Char]) {
@transient val hash32: Int = 0
}
+class DummyClass8 extends KnownSizeEstimation {
+ val x: Int = 0
+
+ override def estimatedSize: Long = 2015
+}
+
class SizeEstimatorSuite
extends SparkFunSuite
with BeforeAndAfterEach
@@ -214,4 +220,10 @@ class SizeEstimatorSuite
// Class should be 32 bytes on s390x if recognised as 64 bit platform
assertResult(32)(SizeEstimator.estimate(new DummyClass7))
}
+
+ test("SizeEstimation can provide the estimated size") {
+ // DummyClass8 provides its size estimation.
+ assertResult(2015)(SizeEstimator.estimate(new DummyClass8))
+ assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8)))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
index b0db0988eeaa..69dbfa9cd714 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -17,10 +17,7 @@
package org.apache.spark.util.collection
-import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
-
import org.apache.spark.SparkFunSuite
-import org.apache.spark.util.{Utils => UUtils}
class BitSetSuite extends SparkFunSuite {
@@ -155,50 +152,4 @@ class BitSetSuite extends SparkFunSuite {
assert(bitsetDiff.nextSetBit(85) === 85)
assert(bitsetDiff.nextSetBit(86) === -1)
}
-
- test("read and write externally") {
- val tempDir = UUtils.createTempDir()
- val outputFile = File.createTempFile("bits", null, tempDir)
-
- val fos = new FileOutputStream(outputFile)
- val oos = new ObjectOutputStream(fos)
-
- // Create BitSet
- val setBits = Seq(0, 9, 1, 10, 90, 96)
- val bitset = new BitSet(100)
-
- for (i <- 0 until 100) {
- assert(!bitset.get(i))
- }
-
- setBits.foreach(i => bitset.set(i))
-
- for (i <- 0 until 100) {
- if (setBits.contains(i)) {
- assert(bitset.get(i))
- } else {
- assert(!bitset.get(i))
- }
- }
- assert(bitset.cardinality() === setBits.size)
-
- bitset.writeExternal(oos)
- oos.close()
-
- val fis = new FileInputStream(outputFile)
- val ois = new ObjectInputStream(fis)
-
- // Read BitSet from the file
- val bitset2 = new BitSet(0)
- bitset2.readExternal(ois)
-
- for (i <- 0 until 100) {
- if (setBits.contains(i)) {
- assert(bitset2.get(i))
- } else {
- assert(!bitset2.get(i))
- }
- }
- assert(bitset2.cardinality() === setBits.size)
- }
}
diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 164a7f396280..6eb6b3391a4a 100644
--- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.Connection
import java.util.Properties
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Literal, If}
import org.apache.spark.tags.DockerTest
@DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
- conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
- + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+ conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+ + "c10 integer[], c11 text[])").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
- + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+ + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
}
test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect()
assert(rows.length == 1)
- val types = rows(0).toSeq.map(x => x.getClass.toString)
- assert(types.length == 10)
- assert(types(0).equals("class java.lang.String"))
- assert(types(1).equals("class java.lang.Integer"))
- assert(types(2).equals("class java.lang.Double"))
- assert(types(3).equals("class java.lang.Long"))
- assert(types(4).equals("class java.lang.Boolean"))
- assert(types(5).equals("class [B"))
- assert(types(6).equals("class [B"))
- assert(types(7).equals("class java.lang.Boolean"))
- assert(types(8).equals("class java.lang.String"))
- assert(types(9).equals("class java.lang.String"))
+ val types = rows(0).toSeq.map(x => x.getClass)
+ assert(types.length == 12)
+ assert(classOf[String].isAssignableFrom(types(0)))
+ assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
+ assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
+ assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
+ assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
+ assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
+ assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
+ assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
+ assert(classOf[String].isAssignableFrom(types(8)))
+ assert(classOf[String].isAssignableFrom(types(9)))
+ assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
+ assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getBoolean(7) == true)
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
+ assert(rows(0).getSeq(10) == Seq(1, 2))
+ assert(rows(0).getSeq(11) == Seq("a", null, "b"))
}
test("Basic write test") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
- df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test only that it doesn't crash.
+ df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
+ // Test write null values.
+ df.select(df.queryExecution.analyzed.output.map { a =>
+ Column(Literal.create(null, a.dataType)).as(a.name)
+ }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
}
diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml
new file mode 100644
index 000000000000..dff3d33bf4ed
--- /dev/null
+++ b/docs/_data/menu-ml.yaml
@@ -0,0 +1,10 @@
+- text: Feature extraction, transformation, and selection
+ url: ml-features.html
+- text: Decision trees for classification and regression
+ url: ml-decision-tree.html
+- text: Ensembles
+ url: ml-ensembles.html
+- text: Linear methods with elastic-net regularization
+ url: ml-linear-methods.html
+- text: Multilayer perceptron classifier
+ url: ml-ann.html
diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml
new file mode 100644
index 000000000000..12d22abd5282
--- /dev/null
+++ b/docs/_data/menu-mllib.yaml
@@ -0,0 +1,75 @@
+- text: Data types
+ url: mllib-data-types.html
+- text: Basic statistics
+ url: mllib-statistics.html
+ subitems:
+ - text: Summary statistics
+ url: mllib-statistics.html#summary-statistics
+ - text: Correlations
+ url: mllib-statistics.html#correlations
+ - text: Stratified sampling
+ url: mllib-statistics.html#stratified-sampling
+ - text: Hypothesis testing
+ url: mllib-statistics.html#hypothesis-testing
+ - text: Random data generation
+ url: mllib-statistics.html#random-data-generation
+- text: Classification and regression
+ url: mllib-classification-regression.html
+ subitems:
+ - text: Linear models (SVMs, logistic regression, linear regression)
+ url: mllib-linear-methods.html
+ - text: Naive Bayes
+ url: mllib-naive-bayes.html
+ - text: decision trees
+ url: mllib-decision-tree.html
+ - text: ensembles of trees (Random Forests and Gradient-Boosted Trees)
+ url: mllib-ensembles.html
+ - text: isotonic regression
+ url: mllib-isotonic-regression.html
+- text: Collaborative filtering
+ url: mllib-collaborative-filtering.html
+ subitems:
+ - text: alternating least squares (ALS)
+ url: mllib-collaborative-filtering.html#collaborative-filtering
+- text: Clustering
+ url: mllib-clustering.html
+ subitems:
+ - text: k-means
+ url: mllib-clustering.html#k-means
+ - text: Gaussian mixture
+ url: mllib-clustering.html#gaussian-mixture
+ - text: power iteration clustering (PIC)
+ url: mllib-clustering.html#power-iteration-clustering-pic
+ - text: latent Dirichlet allocation (LDA)
+ url: mllib-clustering.html#latent-dirichlet-allocation-lda
+ - text: streaming k-means
+ url: mllib-clustering.html#streaming-k-means
+- text: Dimensionality reduction
+ url: mllib-dimensionality-reduction.html
+ subitems:
+ - text: singular value decomposition (SVD)
+ url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd
+ - text: principal component analysis (PCA)
+ url: mllib-dimensionality-reduction.html#principal-component-analysis-pca
+- text: Feature extraction and transformation
+ url: mllib-feature-extraction.html
+- text: Frequent pattern mining
+ url: mllib-frequent-pattern-mining.html
+ subitems:
+ - text: FP-growth
+ url: mllib-frequent-pattern-mining.html#fp-growth
+ - text: association rules
+ url: mllib-frequent-pattern-mining.html#association-rules
+ - text: PrefixSpan
+ url: mllib-frequent-pattern-mining.html#prefix-span
+- text: Evaluation metrics
+ url: mllib-evaluation-metrics.html
+- text: PMML model export
+ url: mllib-pmml-model-export.html
+- text: Optimization (developer)
+ url: mllib-optimization.html
+ subitems:
+ - text: stochastic gradient descent
+ url: mllib-optimization.html#stochastic-gradient-descent-sgd
+ - text: limited-memory BFGS (L-BFGS)
+ url: mllib-optimization.html#limited-memory-bfgs-l-bfgs
diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html
new file mode 100644
index 000000000000..0103e890cc21
--- /dev/null
+++ b/docs/_includes/nav-left-wrapper-ml.html
@@ -0,0 +1,8 @@
+
+
+
spark.ml package
+ {% include nav-left.html nav=include.nav-ml %}
+
spark.mllib package
+ {% include nav-left.html nav=include.nav-mllib %}
+
+
\ No newline at end of file
diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html
new file mode 100644
index 000000000000..73176f413255
--- /dev/null
+++ b/docs/_includes/nav-left.html
@@ -0,0 +1,17 @@
+{% assign navurl = page.url | remove: 'index.html' %}
+
diff --git a/docs/configuration.md b/docs/configuration.md
index c276e8e90dec..c496146e3ed6 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful
daily
Set the time interval by which the executor logs will be rolled over.
- Rolling is disabled by default. Valid values are daily, hourly, minutely or
+ Rolling is disabled by default. Valid values are daily, hourly, minutely or
any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles
for automatic cleaning of old logs.
@@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful
spark.python.profile
false
- Enable profiling in Python worker, the profile result will show up by sc.show_profiles(),
+ Enable profiling in Python worker, the profile result will show up by sc.show_profiles(),
or it will be displayed before the driver exiting. It also can be dumped into disk by
- sc.dump_profiles(path). If some of the profile results had been displayed manually,
+ sc.dump_profiles(path). If some of the profile results had been displayed manually,
they will not be displayed automatically before driver exiting.
- By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by
- passing a profiler class in as a parameter to the SparkContext constructor.
+ By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by
+ passing a profiler class in as a parameter to the SparkContext constructor.
@@ -722,17 +722,20 @@ Apart from these, the following properties are also available, and may be useful
Fraction of the heap space used for execution and storage. The lower this is, the more
frequently spills and cached data eviction occur. The purpose of this config is to set
aside memory for internal metadata, user data structures, and imprecise size estimation
- in the case of sparse, unusually large records.
+ in the case of sparse, unusually large records. Leaving this at the default value is
+ recommended. For more detail, see
+ this description.
spark.memory.storageFraction
0.5
- T​he size of the storage region within the space set aside by
- s​park.memory.fraction. This region is not statically reserved, but dynamically
- allocated as cache requests come in. ​Cached data may be evicted only if total storage exceeds
- this region.
+ Amount of storage memory immune to eviction, expressed as a fraction of the size of the
+ region set aside by s​park.memory.fraction. The higher this is, the less
+ working memory may be available to execution and tasks may spill to disk more often.
+ Leaving this at the default value is recommended. For more detail, see
+ this description.
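
Since the two properties above interact (the storage region is carved out of the unified region controlled by spark.memory.fraction), a short sketch of setting them explicitly may help. The values are purely illustrative; as the text says, leaving the defaults alone is recommended.

```scala
import org.apache.spark.SparkConf

// Illustrative only: tune the unified memory manager knobs described above.
val conf = new SparkConf()
  .set("spark.memory.fraction", "0.75")        // heap share for execution + storage (example value)
  .set("spark.memory.storageFraction", "0.5")  // part of that region immune to eviction (documented default)
```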
diff --git a/docs/css/main.css b/docs/css/main.css
index d770173be101..356b324d6303 100755
--- a/docs/css/main.css
+++ b/docs/css/main.css
@@ -39,8 +39,18 @@
margin-left: 10px;
}
+body .container-wrapper {
+ position: absolute;
+ width: 100%;
+ display: flex;
+}
+
body #content {
+ position: relative;
+
line-height: 1.6; /* Inspired by Github's wiki style */
+ background-color: white;
+ padding-left: 15px;
}
.title {
@@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu {
* AnchorJS (anchor links when hovering over headers)
*/
a.anchorjs-link:hover { text-decoration: none; }
+
+
+/**
+ * The left navigation bar.
+ */
+.left-menu-wrapper {
+ position: absolute;
+ height: 100%;
+
+ width: 256px;
+ margin-top: -20px;
+ padding-top: 20px;
+ background-color: #F0F8FC;
+}
+
+.left-menu {
+ position: fixed;
+ max-width: 350px;
+
+ padding-right: 10px;
+ width: 256px;
+}
+
+.left-menu h3 {
+ margin-left: 10px;
+ line-height: 30px;
+}
\ No newline at end of file
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index a3c34cb6796f..36327c6efeaf 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks
is useful when you expect large numbers of not overly active applications, such as shell sessions from
separate users. However, it comes with a risk of less predictable latency, because it may take a while for
an application to gain back cores on one node when it has work to do. To use this mode, simply use a
-`mesos://` URL without setting `spark.mesos.coarse` to true.
+`mesos://` URL and set `spark.mesos.coarse` to false.
Note that none of the modes currently provide memory sharing across applications. If you would like to share
data this way, we recommend running a single server application that can serve multiple requests by querying
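
As a hedged sketch of the corrected instruction above (the application name and master URL are placeholders), fine-grained Mesos mode amounts to a mesos:// master with coarse-grained mode explicitly disabled:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("fine-grained-example")             // placeholder app name
  .setMaster("mesos://mesos-master.example:5050") // placeholder master URL
  .set("spark.mesos.coarse", "false")             // fine-grained mode
val sc = new SparkContext(conf)
```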
diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md
index ce15f5e6466e..f6c3c30d5334 100644
--- a/docs/ml-ensembles.md
+++ b/docs/ml-ensembles.md
@@ -115,194 +115,21 @@ We use two feature transformers to prepare the data; these help index categories
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details.
-{% highlight scala %}
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.classification.RandomForestClassifier
-import org.apache.spark.ml.classification.RandomForestClassificationModel
-import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
-import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
-
-// Load and parse the data file, converting it to a DataFrame.
-val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-// Index labels, adding metadata to the label column.
-// Fit on whole dataset to include all labels in index.
-val labelIndexer = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("indexedLabel")
- .fit(data)
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-val featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data)
-
-// Split the data into training and test sets (30% held out for testing)
-val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
-
-// Train a RandomForest model.
-val rf = new RandomForestClassifier()
- .setLabelCol("indexedLabel")
- .setFeaturesCol("indexedFeatures")
- .setNumTrees(10)
-
-// Convert indexed labels back to original labels.
-val labelConverter = new IndexToString()
- .setInputCol("prediction")
- .setOutputCol("predictedLabel")
- .setLabels(labelIndexer.labels)
-
-// Chain indexers and forest in a Pipeline
-val pipeline = new Pipeline()
- .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
-
-// Train model. This also runs the indexers.
-val model = pipeline.fit(trainingData)
-
-// Make predictions.
-val predictions = model.transform(testData)
-
-// Select example rows to display.
-predictions.select("predictedLabel", "label", "features").show(5)
-
-// Select (prediction, true label) and compute test error
-val evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("indexedLabel")
- .setPredictionCol("prediction")
- .setMetricName("precision")
-val accuracy = evaluator.evaluate(predictions)
-println("Test Error = " + (1.0 - accuracy))
-
-val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
-println("Learned classification forest model:\n" + rfModel.toDebugString)
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %}
Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details.
-{% highlight java %}
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.classification.RandomForestClassifier;
-import org.apache.spark.ml.classification.RandomForestClassificationModel;
-import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
-import org.apache.spark.ml.feature.*;
-import org.apache.spark.sql.DataFrame;
-
-// Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read().format("libsvm")
- .load("data/mllib/sample_libsvm_data.txt");
-
-// Index labels, adding metadata to the label column.
-// Fit on whole dataset to include all labels in index.
-StringIndexerModel labelIndexer = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("indexedLabel")
- .fit(data);
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-VectorIndexerModel featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data);
-
-// Split the data into training and test sets (30% held out for testing)
-DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
-DataFrame trainingData = splits[0];
-DataFrame testData = splits[1];
-
-// Train a RandomForest model.
-RandomForestClassifier rf = new RandomForestClassifier()
- .setLabelCol("indexedLabel")
- .setFeaturesCol("indexedFeatures");
-
-// Convert indexed labels back to original labels.
-IndexToString labelConverter = new IndexToString()
- .setInputCol("prediction")
- .setOutputCol("predictedLabel")
- .setLabels(labelIndexer.labels());
-
-// Chain indexers and forest in a Pipeline
-Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});
-
-// Train model. This also runs the indexers.
-PipelineModel model = pipeline.fit(trainingData);
-
-// Make predictions.
-DataFrame predictions = model.transform(testData);
-
-// Select example rows to display.
-predictions.select("predictedLabel", "label", "features").show(5);
-
-// Select (prediction, true label) and compute test error
-MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("indexedLabel")
- .setPredictionCol("prediction")
- .setMetricName("precision");
-double accuracy = evaluator.evaluate(predictions);
-System.out.println("Test Error = " + (1.0 - accuracy));
-
-RandomForestClassificationModel rfModel =
- (RandomForestClassificationModel)(model.stages()[2]);
-System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %}
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details.
-{% highlight python %}
-from pyspark.ml import Pipeline
-from pyspark.ml.classification import RandomForestClassifier
-from pyspark.ml.feature import StringIndexer, VectorIndexer
-from pyspark.ml.evaluation import MulticlassClassificationEvaluator
-
-# Load and parse the data file, converting it to a DataFrame.
-data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-# Index labels, adding metadata to the label column.
-# Fit on whole dataset to include all labels in index.
-labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
-# Automatically identify categorical features, and index them.
-# Set maxCategories so features with > 4 distinct values are treated as continuous.
-featureIndexer =\
- VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
-
-# Split the data into training and test sets (30% held out for testing)
-(trainingData, testData) = data.randomSplit([0.7, 0.3])
-
-# Train a RandomForest model.
-rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")
-
-# Chain indexers and forest in a Pipeline
-pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf])
-
-# Train model. This also runs the indexers.
-model = pipeline.fit(trainingData)
-
-# Make predictions.
-predictions = model.transform(testData)
-
-# Select example rows to display.
-predictions.select("prediction", "indexedLabel", "features").show(5)
-
-# Select (prediction, true label) and compute test error
-evaluator = MulticlassClassificationEvaluator(
- labelCol="indexedLabel", predictionCol="prediction", metricName="precision")
-accuracy = evaluator.evaluate(predictions)
-print "Test Error = %g" % (1.0 - accuracy)
-
-rfModel = model.stages[2]
-print rfModel # summary only
-{% endhighlight %}
+{% include_example python/ml/random_forest_classifier_example.py %}
@@ -316,167 +143,21 @@ We use a feature transformer to index categorical features, adding metadata to t
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details.
-{% highlight scala %}
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.regression.RandomForestRegressor
-import org.apache.spark.ml.regression.RandomForestRegressionModel
-import org.apache.spark.ml.feature.VectorIndexer
-import org.apache.spark.ml.evaluation.RegressionEvaluator
-
-// Load and parse the data file, converting it to a DataFrame.
-val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-val featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data)
-
-// Split the data into training and test sets (30% held out for testing)
-val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
-
-// Train a RandomForest model.
-val rf = new RandomForestRegressor()
- .setLabelCol("label")
- .setFeaturesCol("indexedFeatures")
-
-// Chain indexer and forest in a Pipeline
-val pipeline = new Pipeline()
- .setStages(Array(featureIndexer, rf))
-
-// Train model. This also runs the indexer.
-val model = pipeline.fit(trainingData)
-
-// Make predictions.
-val predictions = model.transform(testData)
-
-// Select example rows to display.
-predictions.select("prediction", "label", "features").show(5)
-
-// Select (prediction, true label) and compute test error
-val evaluator = new RegressionEvaluator()
- .setLabelCol("label")
- .setPredictionCol("prediction")
- .setMetricName("rmse")
-val rmse = evaluator.evaluate(predictions)
-println("Root Mean Squared Error (RMSE) on test data = " + rmse)
-
-val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel]
-println("Learned regression forest model:\n" + rfModel.toDebugString)
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %}
Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details.
-{% highlight java %}
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.evaluation.RegressionEvaluator;
-import org.apache.spark.ml.feature.VectorIndexer;
-import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.ml.regression.RandomForestRegressionModel;
-import org.apache.spark.ml.regression.RandomForestRegressor;
-import org.apache.spark.sql.DataFrame;
-
-// Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read().format("libsvm")
- .load("data/mllib/sample_libsvm_data.txt");
-
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-VectorIndexerModel featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data);
-
-// Split the data into training and test sets (30% held out for testing)
-DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
-DataFrame trainingData = splits[0];
-DataFrame testData = splits[1];
-
-// Train a RandomForest model.
-RandomForestRegressor rf = new RandomForestRegressor()
- .setLabelCol("label")
- .setFeaturesCol("indexedFeatures");
-
-// Chain indexer and forest in a Pipeline
-Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {featureIndexer, rf});
-
-// Train model. This also runs the indexer.
-PipelineModel model = pipeline.fit(trainingData);
-
-// Make predictions.
-DataFrame predictions = model.transform(testData);
-
-// Select example rows to display.
-predictions.select("prediction", "label", "features").show(5);
-
-// Select (prediction, true label) and compute test error
-RegressionEvaluator evaluator = new RegressionEvaluator()
- .setLabelCol("label")
- .setPredictionCol("prediction")
- .setMetricName("rmse");
-double rmse = evaluator.evaluate(predictions);
-System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
-
-RandomForestRegressionModel rfModel =
- (RandomForestRegressionModel)(model.stages()[1]);
-System.out.println("Learned regression forest model:\n" + rfModel.toDebugString());
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %}
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details.
-{% highlight python %}
-from pyspark.ml import Pipeline
-from pyspark.ml.regression import RandomForestRegressor
-from pyspark.ml.feature import VectorIndexer
-from pyspark.ml.evaluation import RegressionEvaluator
-
-# Load and parse the data file, converting it to a DataFrame.
-data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-# Automatically identify categorical features, and index them.
-# Set maxCategories so features with > 4 distinct values are treated as continuous.
-featureIndexer =\
- VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
-
-# Split the data into training and test sets (30% held out for testing)
-(trainingData, testData) = data.randomSplit([0.7, 0.3])
-
-# Train a RandomForest model.
-rf = RandomForestRegressor(featuresCol="indexedFeatures")
-
-# Chain indexer and forest in a Pipeline
-pipeline = Pipeline(stages=[featureIndexer, rf])
-
-# Train model. This also runs the indexer.
-model = pipeline.fit(trainingData)
-
-# Make predictions.
-predictions = model.transform(testData)
-
-# Select example rows to display.
-predictions.select("prediction", "label", "features").show(5)
-
-# Select (prediction, true label) and compute test error
-evaluator = RegressionEvaluator(
- labelCol="label", predictionCol="prediction", metricName="rmse")
-rmse = evaluator.evaluate(predictions)
-print "Root Mean Squared Error (RMSE) on test data = %g" % rmse
-
-rfModel = model.stages[1]
-print rfModel # summary only
-{% endhighlight %}
+{% include_example python/ml/random_forest_regressor_example.py %}
@@ -560,194 +241,21 @@ We use two feature transformers to prepare the data; these help index categories
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details.
-{% highlight scala %}
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.classification.GBTClassifier
-import org.apache.spark.ml.classification.GBTClassificationModel
-import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
-import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
-
-// Load and parse the data file, converting it to a DataFrame.
-val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-// Index labels, adding metadata to the label column.
-// Fit on whole dataset to include all labels in index.
-val labelIndexer = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("indexedLabel")
- .fit(data)
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-val featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data)
-
-// Split the data into training and test sets (30% held out for testing)
-val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
-
-// Train a GBT model.
-val gbt = new GBTClassifier()
- .setLabelCol("indexedLabel")
- .setFeaturesCol("indexedFeatures")
- .setMaxIter(10)
-
-// Convert indexed labels back to original labels.
-val labelConverter = new IndexToString()
- .setInputCol("prediction")
- .setOutputCol("predictedLabel")
- .setLabels(labelIndexer.labels)
-
-// Chain indexers and GBT in a Pipeline
-val pipeline = new Pipeline()
- .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))
-
-// Train model. This also runs the indexers.
-val model = pipeline.fit(trainingData)
-
-// Make predictions.
-val predictions = model.transform(testData)
-
-// Select example rows to display.
-predictions.select("predictedLabel", "label", "features").show(5)
-
-// Select (prediction, true label) and compute test error
-val evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("indexedLabel")
- .setPredictionCol("prediction")
- .setMetricName("precision")
-val accuracy = evaluator.evaluate(predictions)
-println("Test Error = " + (1.0 - accuracy))
-
-val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
-println("Learned classification GBT model:\n" + gbtModel.toDebugString)
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %}
Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details.
-{% highlight java %}
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.classification.GBTClassifier;
-import org.apache.spark.ml.classification.GBTClassificationModel;
-import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
-import org.apache.spark.ml.feature.*;
-import org.apache.spark.sql.DataFrame;
-
-// Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
-
-// Index labels, adding metadata to the label column.
-// Fit on whole dataset to include all labels in index.
-StringIndexerModel labelIndexer = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("indexedLabel")
- .fit(data);
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-VectorIndexerModel featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data);
-
-// Split the data into training and test sets (30% held out for testing)
-DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
-DataFrame trainingData = splits[0];
-DataFrame testData = splits[1];
-
-// Train a GBT model.
-GBTClassifier gbt = new GBTClassifier()
- .setLabelCol("indexedLabel")
- .setFeaturesCol("indexedFeatures")
- .setMaxIter(10);
-
-// Convert indexed labels back to original labels.
-IndexToString labelConverter = new IndexToString()
- .setInputCol("prediction")
- .setOutputCol("predictedLabel")
- .setLabels(labelIndexer.labels());
-
-// Chain indexers and GBT in a Pipeline
-Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});
-
-// Train model. This also runs the indexers.
-PipelineModel model = pipeline.fit(trainingData);
-
-// Make predictions.
-DataFrame predictions = model.transform(testData);
-
-// Select example rows to display.
-predictions.select("predictedLabel", "label", "features").show(5);
-
-// Select (prediction, true label) and compute test error
-MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("indexedLabel")
- .setPredictionCol("prediction")
- .setMetricName("precision");
-double accuracy = evaluator.evaluate(predictions);
-System.out.println("Test Error = " + (1.0 - accuracy));
-
-GBTClassificationModel gbtModel =
- (GBTClassificationModel)(model.stages()[2]);
-System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %}
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details.
-{% highlight python %}
-from pyspark.ml import Pipeline
-from pyspark.ml.classification import GBTClassifier
-from pyspark.ml.feature import StringIndexer, VectorIndexer
-from pyspark.ml.evaluation import MulticlassClassificationEvaluator
-
-# Load and parse the data file, converting it to a DataFrame.
-data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-# Index labels, adding metadata to the label column.
-# Fit on whole dataset to include all labels in index.
-labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
-# Automatically identify categorical features, and index them.
-# Set maxCategories so features with > 4 distinct values are treated as continuous.
-featureIndexer =\
- VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
-
-# Split the data into training and test sets (30% held out for testing)
-(trainingData, testData) = data.randomSplit([0.7, 0.3])
-
-# Train a GBT model.
-gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)
-
-# Chain indexers and GBT in a Pipeline
-pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])
-
-# Train model. This also runs the indexers.
-model = pipeline.fit(trainingData)
-
-# Make predictions.
-predictions = model.transform(testData)
-
-# Select example rows to display.
-predictions.select("prediction", "indexedLabel", "features").show(5)
-
-# Select (prediction, true label) and compute test error
-evaluator = MulticlassClassificationEvaluator(
- labelCol="indexedLabel", predictionCol="prediction", metricName="precision")
-accuracy = evaluator.evaluate(predictions)
-print "Test Error = %g" % (1.0 - accuracy)
-
-gbtModel = model.stages[2]
-print gbtModel # summary only
-{% endhighlight %}
+{% include_example python/ml/gradient_boosted_tree_classifier_example.py %}
@@ -761,168 +269,21 @@ be true in general.
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details.
-{% highlight scala %}
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.regression.GBTRegressor
-import org.apache.spark.ml.regression.GBTRegressionModel
-import org.apache.spark.ml.feature.VectorIndexer
-import org.apache.spark.ml.evaluation.RegressionEvaluator
-
-// Load and parse the data file, converting it to a DataFrame.
-val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-val featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data)
-
-// Split the data into training and test sets (30% held out for testing)
-val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
-
-// Train a GBT model.
-val gbt = new GBTRegressor()
- .setLabelCol("label")
- .setFeaturesCol("indexedFeatures")
- .setMaxIter(10)
-
-// Chain indexer and GBT in a Pipeline
-val pipeline = new Pipeline()
- .setStages(Array(featureIndexer, gbt))
-
-// Train model. This also runs the indexer.
-val model = pipeline.fit(trainingData)
-
-// Make predictions.
-val predictions = model.transform(testData)
-
-// Select example rows to display.
-predictions.select("prediction", "label", "features").show(5)
-
-// Select (prediction, true label) and compute test error
-val evaluator = new RegressionEvaluator()
- .setLabelCol("label")
- .setPredictionCol("prediction")
- .setMetricName("rmse")
-val rmse = evaluator.evaluate(predictions)
-println("Root Mean Squared Error (RMSE) on test data = " + rmse)
-
-val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
-println("Learned regression GBT model:\n" + gbtModel.toDebugString)
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %}
Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details.
-{% highlight java %}
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineModel;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.evaluation.RegressionEvaluator;
-import org.apache.spark.ml.feature.VectorIndexer;
-import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.ml.regression.GBTRegressionModel;
-import org.apache.spark.ml.regression.GBTRegressor;
-import org.apache.spark.sql.DataFrame;
-
-// Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
-
-// Automatically identify categorical features, and index them.
-// Set maxCategories so features with > 4 distinct values are treated as continuous.
-VectorIndexerModel featureIndexer = new VectorIndexer()
- .setInputCol("features")
- .setOutputCol("indexedFeatures")
- .setMaxCategories(4)
- .fit(data);
-
-// Split the data into training and test sets (30% held out for testing)
-DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
-DataFrame trainingData = splits[0];
-DataFrame testData = splits[1];
-
-// Train a GBT model.
-GBTRegressor gbt = new GBTRegressor()
- .setLabelCol("label")
- .setFeaturesCol("indexedFeatures")
- .setMaxIter(10);
-
-// Chain indexer and GBT in a Pipeline
-Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {featureIndexer, gbt});
-
-// Train model. This also runs the indexer.
-PipelineModel model = pipeline.fit(trainingData);
-
-// Make predictions.
-DataFrame predictions = model.transform(testData);
-
-// Select example rows to display.
-predictions.select("prediction", "label", "features").show(5);
-
-// Select (prediction, true label) and compute test error
-RegressionEvaluator evaluator = new RegressionEvaluator()
- .setLabelCol("label")
- .setPredictionCol("prediction")
- .setMetricName("rmse");
-double rmse = evaluator.evaluate(predictions);
-System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
-
-GBTRegressionModel gbtModel =
- (GBTRegressionModel)(model.stages()[1]);
-System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString());
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %}
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details.
-{% highlight python %}
-from pyspark.ml import Pipeline
-from pyspark.ml.regression import GBTRegressor
-from pyspark.ml.feature import VectorIndexer
-from pyspark.ml.evaluation import RegressionEvaluator
-
-# Load and parse the data file, converting it to a DataFrame.
-data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-# Automatically identify categorical features, and index them.
-# Set maxCategories so features with > 4 distinct values are treated as continuous.
-featureIndexer =\
- VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
-
-# Split the data into training and test sets (30% held out for testing)
-(trainingData, testData) = data.randomSplit([0.7, 0.3])
-
-# Train a GBT model.
-gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10)
-
-# Chain indexer and GBT in a Pipeline
-pipeline = Pipeline(stages=[featureIndexer, gbt])
-
-# Train model. This also runs the indexer.
-model = pipeline.fit(trainingData)
-
-# Make predictions.
-predictions = model.transform(testData)
-
-# Select example rows to display.
-predictions.select("prediction", "label", "features").show(5)
-
-# Select (prediction, true label) and compute test error
-evaluator = RegressionEvaluator(
- labelCol="label", predictionCol="prediction", metricName="rmse")
-rmse = evaluator.evaluate(predictions)
-print "Root Mean Squared Error (RMSE) on test data = %g" % rmse
-
-gbtModel = model.stages[1]
-print gbtModel # summary only
-{% endhighlight %}
+{% include_example python/ml/gradient_boosted_tree_regressor_example.py %}
@@ -945,100 +306,13 @@ The example below demonstrates how to load the
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details.
-{% highlight scala %}
-import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.sql.{Row, SQLContext}
-
-val sqlContext = new SQLContext(sc)
-
-// parse data into dataframe
-val data = sqlContext.read.format("libsvm")
- .load("data/mllib/sample_multiclass_classification_data.txt")
-val Array(train, test) = data.randomSplit(Array(0.7, 0.3))
-
-// instantiate multiclass learner and train
-val ovr = new OneVsRest().setClassifier(new LogisticRegression)
-
-val ovrModel = ovr.fit(train)
-
-// score model on test data
-val predictions = ovrModel.transform(test).select("prediction", "label")
-val predictionsAndLabels = predictions.map {case Row(p: Double, l: Double) => (p, l)}
-
-// compute confusion matrix
-val metrics = new MulticlassMetrics(predictionsAndLabels)
-println(metrics.confusionMatrix)
-
-// the Iris DataSet has three classes
-val numClasses = 3
-
-println("label\tfpr\n")
-(0 until numClasses).foreach { index =>
- val label = index.toDouble
- println(label + "\t" + metrics.falsePositiveRate(label))
-}
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %}
Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details.
-{% highlight java %}
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegression;
-import org.apache.spark.ml.classification.OneVsRest;
-import org.apache.spark.ml.classification.OneVsRestModel;
-import org.apache.spark.mllib.evaluation.MulticlassMetrics;
-import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-
-SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
-JavaSparkContext jsc = new JavaSparkContext(conf);
-SQLContext jsql = new SQLContext(jsc);
-
-DataFrame dataFrame = jsql.read().format("libsvm")
- .load("data/mllib/sample_multiclass_classification_data.txt");
-
-DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345);
-DataFrame train = splits[0];
-DataFrame test = splits[1];
-
-// instantiate the One Vs Rest Classifier
-OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression());
-
-// train the multiclass model
-OneVsRestModel ovrModel = ovr.fit(train.cache());
-
-// score the model on test data
-DataFrame predictions = ovrModel
- .transform(test)
- .select("prediction", "label");
-
-// obtain metrics
-MulticlassMetrics metrics = new MulticlassMetrics(predictions);
-Matrix confusionMatrix = metrics.confusionMatrix();
-
-// output the Confusion Matrix
-System.out.println("Confusion Matrix");
-System.out.println(confusionMatrix);
-
-// compute the false positive rate per label
-System.out.println();
-System.out.println("label\tfpr\n");
-
-// the Iris DataSet has three classes
-int numClasses = 3;
-for (int index = 0; index < numClasses; index++) {
- double label = (double) index;
- System.out.print(label);
- System.out.print("\t");
- System.out.print(metrics.falsePositiveRate(label));
- System.out.println();
-}
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %}
diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md
index 85edfd373465..0c13d7d0c82b 100644
--- a/docs/ml-linear-methods.md
+++ b/docs/ml-linear-methods.md
@@ -57,77 +57,15 @@ $\alpha$ and `regParam` corresponds to $\lambda$.
-{% highlight scala %}
-import org.apache.spark.ml.classification.LogisticRegression
-
-// Load training data
-val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-val lr = new LogisticRegression()
- .setMaxIter(10)
- .setRegParam(0.3)
- .setElasticNetParam(0.8)
-
-// Fit the model
-val lrModel = lr.fit(training)
-
-// Print the coefficients and intercept for logistic regression
-println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %}
-{% highlight java %}
-import org.apache.spark.ml.classification.LogisticRegression;
-import org.apache.spark.ml.classification.LogisticRegressionModel;
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-
-public class LogisticRegressionWithElasticNetExample {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf()
- .setAppName("Logistic Regression with Elastic Net Example");
-
- SparkContext sc = new SparkContext(conf);
- SQLContext sql = new SQLContext(sc);
- String path = "data/mllib/sample_libsvm_data.txt";
-
- // Load training data
- DataFrame training = sql.read().format("libsvm").load(path);
-
- LogisticRegression lr = new LogisticRegression()
- .setMaxIter(10)
- .setRegParam(0.3)
- .setElasticNetParam(0.8);
-
- // Fit the model
- LogisticRegressionModel lrModel = lr.fit(training);
-
- // Print the coefficients and intercept for logistic regression
- System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
- }
-}
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %}
-{% highlight python %}
-from pyspark.ml.classification import LogisticRegression
-
-# Load training data
-training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
-
-# Fit the model
-lrModel = lr.fit(training)
-
-# Print the coefficients and intercept for logistic regression
-print("Coefficients: " + str(lrModel.coefficients))
-print("Intercept: " + str(lrModel.intercept))
-{% endhighlight %}
+{% include_example python/ml/logistic_regression_with_elastic_net.py %}
@@ -152,33 +90,7 @@ This will likely change when multiclass classification is supported.
Continuing the earlier example:
-{% highlight scala %}
-import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
-
-// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
-val trainingSummary = lrModel.summary
-
-// Obtain the objective per iteration.
-val objectiveHistory = trainingSummary.objectiveHistory
-objectiveHistory.foreach(loss => println(loss))
-
-// Obtain the metrics useful to judge performance on test data.
-// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
-// binary classification problem.
-val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
-
-// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
-val roc = binarySummary.roc
-roc.show()
-println(binarySummary.areaUnderROC)
-
-// Set the model threshold to maximize F-Measure
-val fMeasure = binarySummary.fMeasureByThreshold
-val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
-val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).
- select("threshold").head().getDouble(0)
-lrModel.setThreshold(bestThreshold)
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %}
@@ -192,39 +104,7 @@ This will likely change when multiclass classification is supported.
Continuing the earlier example:
-{% highlight java %}
-import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
-import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
-import org.apache.spark.sql.functions;
-
-// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
-LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
-
-// Obtain the loss per iteration.
-double[] objectiveHistory = trainingSummary.objectiveHistory();
-for (double lossPerIteration : objectiveHistory) {
- System.out.println(lossPerIteration);
-}
-
-// Obtain the metrics useful to judge performance on test data.
-// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
-// binary classification problem.
-BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary;
-
-// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
-DataFrame roc = binarySummary.roc();
-roc.show();
-roc.select("FPR").show();
-System.out.println(binarySummary.areaUnderROC());
-
-// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
-// this selected threshold.
-DataFrame fMeasure = binarySummary.fMeasureByThreshold();
-double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
-double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)).
- select("threshold").head().getDouble(0);
-lrModel.setThreshold(bestThreshold);
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %}
@@ -244,98 +124,16 @@ regression model and extracting model summary statistics.
-{% highlight scala %}
-import org.apache.spark.ml.regression.LinearRegression
-
-// Load training data
-val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-val lr = new LinearRegression()
- .setMaxIter(10)
- .setRegParam(0.3)
- .setElasticNetParam(0.8)
-
-// Fit the model
-val lrModel = lr.fit(training)
-
-// Print the coefficients and intercept for linear regression
-println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
-
-// Summarize the model over the training set and print out some metrics
-val trainingSummary = lrModel.summary
-println(s"numIterations: ${trainingSummary.totalIterations}")
-println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
-trainingSummary.residuals.show()
-println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
-println(s"r2: ${trainingSummary.r2}")
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %}
-{% highlight java %}
-import org.apache.spark.ml.regression.LinearRegression;
-import org.apache.spark.ml.regression.LinearRegressionModel;
-import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-
-public class LinearRegressionWithElasticNetExample {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf()
- .setAppName("Linear Regression with Elastic Net Example");
-
- SparkContext sc = new SparkContext(conf);
- SQLContext sql = new SQLContext(sc);
- String path = "data/mllib/sample_libsvm_data.txt";
-
- // Load training data
- DataFrame training = sql.read().format("libsvm").load(path);
-
- LinearRegression lr = new LinearRegression()
- .setMaxIter(10)
- .setRegParam(0.3)
- .setElasticNetParam(0.8);
-
- // Fit the model
- LinearRegressionModel lrModel = lr.fit(training);
-
- // Print the coefficients and intercept for linear regression
- System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
-
- // Summarize the model over the training set and print out some metrics
- LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
- System.out.println("numIterations: " + trainingSummary.totalIterations());
- System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
- trainingSummary.residuals().show();
- System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
- System.out.println("r2: " + trainingSummary.r2());
- }
-}
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %}
-{% highlight python %}
-from pyspark.ml.regression import LinearRegression
-
-# Load training data
-training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
-
-lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
-
-# Fit the model
-lrModel = lr.fit(training)
-
-# Print the coefficients and intercept for linear regression
-print("Coefficients: " + str(lrModel.coefficients))
-print("Intercept: " + str(lrModel.intercept))
-
-# Linear regression model summary is not yet supported in Python.
-{% endhighlight %}
+{% include_example python/ml/linear_regression_with_elastic_net.py %}
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index f73eff637dc3..6924037b941f 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -104,214 +104,21 @@ data, and evaluate the performance of the algorithm by several binary evaluation
Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API.
-{% highlight scala %}
-import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.MLUtils
-
-// Load training data in LIBSVM format
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
-
-// Split data into training (60%) and test (40%)
-val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
-training.cache()
-
-// Run training algorithm to build the model
-val model = new LogisticRegressionWithLBFGS()
- .setNumClasses(2)
- .run(training)
-
-// Clear the prediction threshold so the model will return probabilities
-model.clearThreshold
-
-// Compute raw scores on the test set
-val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
- val prediction = model.predict(features)
- (prediction, label)
-}
-
-// Instantiate metrics object
-val metrics = new BinaryClassificationMetrics(predictionAndLabels)
-
-// Precision by threshold
-val precision = metrics.precisionByThreshold
-precision.foreach { case (t, p) =>
- println(s"Threshold: $t, Precision: $p")
-}
-
-// Recall by threshold
-val recall = metrics.recallByThreshold
-recall.foreach { case (t, r) =>
- println(s"Threshold: $t, Recall: $r")
-}
-
-// Precision-Recall Curve
-val PRC = metrics.pr
-
-// F-measure
-val f1Score = metrics.fMeasureByThreshold
-f1Score.foreach { case (t, f) =>
- println(s"Threshold: $t, F-score: $f, Beta = 1")
-}
-
-val beta = 0.5
-val fScore = metrics.fMeasureByThreshold(beta)
-fScore.foreach { case (t, f) =>
- println(s"Threshold: $t, F-score: $f, Beta = 0.5")
-}
-
-// AUPRC
-val auPRC = metrics.areaUnderPR
-println("Area under precision-recall curve = " + auPRC)
-
-// Compute thresholds used in ROC and PR curves
-val thresholds = precision.map(_._1)
-
-// ROC Curve
-val roc = metrics.roc
-
-// AUROC
-val auROC = metrics.areaUnderROC
-println("Area under ROC = " + auROC)
-
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API.
-{% highlight java %}
-import scala.Tuple2;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.rdd.RDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
-
-public class BinaryClassification {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics");
- SparkContext sc = new SparkContext(conf);
- String path = "data/mllib/sample_binary_classification_data.txt";
- JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
-
- // Split initial RDD into two... [60% training data, 40% testing data].
- JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
- JavaRDD<LabeledPoint> training = splits[0].cache();
- JavaRDD<LabeledPoint> test = splits[1];
-
- // Run training algorithm to build the model.
- final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
- .setNumClasses(2)
- .run(training.rdd());
-
- // Clear the prediction threshold so the model will return probabilities
- model.clearThreshold();
-
- // Compute raw scores on the test set.
- JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
- new Function<LabeledPoint, Tuple2<Object, Object>>() {
- public Tuple2<Object, Object> call(LabeledPoint p) {
Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API.
-{% highlight python %}
-from pyspark.mllib.classification import LogisticRegressionWithLBFGS
-from pyspark.mllib.evaluation import BinaryClassificationMetrics
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.util import MLUtils
-
-# Several of the methods available in scala are currently missing from pyspark
-
-# Load training data in LIBSVM format
-data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
-
-# Split data into training (60%) and test (40%)
-training, test = data.randomSplit([0.6, 0.4], seed = 11L)
-training.cache()
-
-# Run training algorithm to build the model
-model = LogisticRegressionWithLBFGS.train(training)
-
-# Compute raw scores on the test set
-predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label))
-
-# Instantiate metrics object
-metrics = BinaryClassificationMetrics(predictionAndLabels)
-
-# Area under precision-recall curve
-print("Area under PR = %s" % metrics.areaUnderPR)
-
-# Area under ROC curve
-print("Area under ROC = %s" % metrics.areaUnderROC)
-
-{% endhighlight %}
-
+{% include_example python/mllib/binary_classification_metrics_example.py %}
@@ -433,204 +240,21 @@ the data, and evaluate the performance of the algorithm by several multiclass cl
Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API.
-{% highlight scala %}
-import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.MLUtils
-
-// Load training data in LIBSVM format
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
-
-// Split data into training (60%) and test (40%)
-val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
-training.cache()
-
-// Run training algorithm to build the model
-val model = new LogisticRegressionWithLBFGS()
- .setNumClasses(3)
- .run(training)
-
-// Compute raw scores on the test set
-val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
- val prediction = model.predict(features)
- (prediction, label)
-}
-
-// Instantiate metrics object
-val metrics = new MulticlassMetrics(predictionAndLabels)
-
-// Confusion matrix
-println("Confusion matrix:")
-println(metrics.confusionMatrix)
-
-// Overall Statistics
-val precision = metrics.precision
-val recall = metrics.recall // same as true positive rate
-val f1Score = metrics.fMeasure
-println("Summary Statistics")
-println(s"Precision = $precision")
-println(s"Recall = $recall")
-println(s"F1 Score = $f1Score")
-
-// Precision by label
-val labels = metrics.labels
-labels.foreach { l =>
- println(s"Precision($l) = " + metrics.precision(l))
-}
-
-// Recall by label
-labels.foreach { l =>
- println(s"Recall($l) = " + metrics.recall(l))
-}
-
-// False positive rate by label
-labels.foreach { l =>
- println(s"FPR($l) = " + metrics.falsePositiveRate(l))
-}
-
-// F-measure by label
-labels.foreach { l =>
- println(s"F1-Score($l) = " + metrics.fMeasure(l))
-}
-
-// Weighted stats
-println(s"Weighted precision: ${metrics.weightedPrecision}")
-println(s"Weighted recall: ${metrics.weightedRecall}")
-println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
-println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
-
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API.
-{% highlight java %}
-import scala.Tuple2;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.rdd.RDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
-import org.apache.spark.mllib.evaluation.MulticlassMetrics;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
-
-public class MulticlassClassification {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics");
- SparkContext sc = new SparkContext(conf);
- String path = "data/mllib/sample_multiclass_classification_data.txt";
- JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
-
- // Split initial RDD into two... [60% training data, 40% testing data].
- JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
- JavaRDD<LabeledPoint> training = splits[0].cache();
- JavaRDD<LabeledPoint> test = splits[1];
-
- // Run training algorithm to build the model.
- final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
- .setNumClasses(3)
- .run(training.rdd());
-
- // Compute raw scores on the test set.
- JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
- new Function<LabeledPoint, Tuple2<Object, Object>>() {
- public Tuple2<Object, Object> call(LabeledPoint p) {
- Double prediction = model.predict(p.features());
- return new Tuple2<Object, Object>(prediction, p.label());
- }
- }
- );
-
- // Get evaluation metrics.
- MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
-
- // Confusion matrix
- Matrix confusion = metrics.confusionMatrix();
- System.out.println("Confusion matrix: \n" + confusion);
-
- // Overall statistics
- System.out.println("Precision = " + metrics.precision());
- System.out.println("Recall = " + metrics.recall());
- System.out.println("F1 Score = " + metrics.fMeasure());
-
- // Stats by labels
- for (int i = 0; i < metrics.labels().length; i++) {
- System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
- System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
- System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i]));
- }
-
- //Weighted stats
- System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
- System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
- System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
- System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());
-
- // Save and load model
- model.save(sc, "myModelPath");
- LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
- }
-}
-
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API.
-{% highlight python %}
-from pyspark.mllib.classification import LogisticRegressionWithLBFGS
-from pyspark.mllib.util import MLUtils
-from pyspark.mllib.evaluation import MulticlassMetrics
-
-# Load training data in LIBSVM format
-data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
-
-# Split data into training (60%) and test (40%)
-training, test = data.randomSplit([0.6, 0.4], seed = 11L)
-training.cache()
-
-# Run training algorithm to build the model
-model = LogisticRegressionWithLBFGS.train(training, numClasses=3)
-
-# Compute raw scores on the test set
-predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label))
-
-# Instantiate metrics object
-metrics = MulticlassMetrics(predictionAndLabels)
-
-# Overall statistics
-precision = metrics.precision()
-recall = metrics.recall()
-f1Score = metrics.fMeasure()
-print("Summary Stats")
-print("Precision = %s" % precision)
-print("Recall = %s" % recall)
-print("F1 Score = %s" % f1Score)
-
-# Statistics by class
-labels = data.map(lambda lp: lp.label).distinct().collect()
-for label in sorted(labels):
- print("Class %s precision = %s" % (label, metrics.precision(label)))
- print("Class %s recall = %s" % (label, metrics.recall(label)))
- print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)))
-
-# Weighted stats
-print("Weighted recall = %s" % metrics.weightedRecall)
-print("Weighted precision = %s" % metrics.weightedPrecision)
-print("Weighted F(1) Score = %s" % metrics.weightedFMeasure())
-print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5))
-print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate)
-{% endhighlight %}
+{% include_example python/mllib/multi_class_metrics_example.py %}
@@ -1027,280 +518,21 @@ expanded world of non-positive weights are "the same as never having interacted
Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API.
-{% highlight scala %}
-import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics}
-import org.apache.spark.mllib.recommendation.{ALS, Rating}
-
-// Read in the ratings data
-val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line =>
- val fields = line.split("::")
- Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5)
-}.cache()
-
-// Map ratings to 1 or 0, 1 indicating a movie that should be recommended
-val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache()
-
-// Summarize ratings
-val numRatings = ratings.count()
-val numUsers = ratings.map(_.user).distinct().count()
-val numMovies = ratings.map(_.product).distinct().count()
-println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
-
-// Build the model
-val numIterations = 10
-val rank = 10
-val lambda = 0.01
-val model = ALS.train(ratings, rank, numIterations, lambda)
-
-// Define a function to scale ratings from 0 to 1
-def scaledRating(r: Rating): Rating = {
- val scaledRating = math.max(math.min(r.rating, 1.0), 0.0)
- Rating(r.user, r.product, scaledRating)
-}
-
-// Get sorted top ten predictions for each user and then scale from [0, 1]
-val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) =>
- (user, recs.map(scaledRating))
-}
-
-// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document
-// Compare with top ten most relevant documents
-val userMovies = binarizedRatings.groupBy(_.user)
-val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) =>
- (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray)
-}
-
-// Instantiate metrics object
-val metrics = new RankingMetrics(relevantDocuments)
-
-// Precision at K
-Array(1, 3, 5).foreach{ k =>
- println(s"Precision at $k = ${metrics.precisionAt(k)}")
-}
-
-// Mean average precision
-println(s"Mean average precision = ${metrics.meanAveragePrecision}")
-
-// Normalized discounted cumulative gain
-Array(1, 3, 5).foreach{ k =>
- println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
-}
-
-// Get predictions for each data point
-val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating))
-val allRatings = ratings.map(r => ((r.user, r.product), r.rating))
-val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) =>
- (predicted, actual)
-}
-
-// Get the RMSE using regression metrics
-val regressionMetrics = new RegressionMetrics(predictionsAndLabels)
-println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}")
-
-// R-squared
-println(s"R-squared = ${regressionMetrics.r2}")
-
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API.
-{% highlight java %}
-import scala.Tuple2;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.rdd.RDD;
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.function.Function;
-import java.util.*;
-import org.apache.spark.mllib.evaluation.RegressionMetrics;
-import org.apache.spark.mllib.evaluation.RankingMetrics;
-import org.apache.spark.mllib.recommendation.ALS;
-import org.apache.spark.mllib.recommendation.Rating;
-
-// Read in the ratings data
-public class Ranking {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("Ranking Metrics");
- JavaSparkContext sc = new JavaSparkContext(conf);
- String path = "data/mllib/sample_movielens_data.txt";
- JavaRDD data = sc.textFile(path);
- JavaRDD ratings = data.map(
- new Function() {
- public Rating call(String line) {
- String[] parts = line.split("::");
- return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5);
- }
- }
- );
- ratings.cache();
-
- // Train an ALS model
- final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
-
- // Get top 10 recommendations for every user and scale ratings from 0 to 1
- JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
- JavaRDD> userRecsScaled = userRecs.map(
- new Function, Tuple2>() {
- public Tuple2 call(Tuple2 t) {
- Rating[] scaledRatings = new Rating[t._2().length];
- for (int i = 0; i < scaledRatings.length; i++) {
- double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
- scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
- }
- return new Tuple2(t._1(), scaledRatings);
- }
- }
- );
- JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
-
- // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
- JavaRDD binarizedRatings = ratings.map(
- new Function() {
- public Rating call(Rating r) {
- double binaryRating;
- if (r.rating() > 0.0) {
- binaryRating = 1.0;
- }
- else {
- binaryRating = 0.0;
- }
- return new Rating(r.user(), r.product(), binaryRating);
- }
- }
- );
-
- // Group ratings by common user
- JavaPairRDD> userMovies = binarizedRatings.groupBy(
- new Function() {
- public Object call(Rating r) {
- return r.user();
- }
- }
- );
-
- // Get true relevant documents from all user ratings
- JavaPairRDD> userMoviesList = userMovies.mapValues(
- new Function, List>() {
- public List call(Iterable docs) {
- List products = new ArrayList();
- for (Rating r : docs) {
- if (r.rating() > 0.0) {
- products.add(r.product());
- }
- }
- return products;
- }
- }
- );
-
- // Extract the product id from each recommendation
- JavaPairRDD> userRecommendedList = userRecommended.mapValues(
- new Function>() {
- public List call(Rating[] docs) {
- List products = new ArrayList();
- for (Rating r : docs) {
- products.add(r.product());
- }
- return products;
- }
- }
- );
- JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values();
-
- // Instantiate the metrics object
- RankingMetrics metrics = RankingMetrics.of(relevantDocs);
-
- // Precision and NDCG at k
- Integer[] kVector = {1, 3, 5};
- for (Integer k : kVector) {
- System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
- System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
- }
-
- // Mean average precision
- System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
-
- // Evaluate the model using numerical ratings and regression metrics
- JavaRDD> userProducts = ratings.map(
- new Function>() {
- public Tuple2 call(Rating r) {
- return new Tuple2(r.user(), r.product());
- }
- }
- );
- JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD(
- model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
- new Function, Object>>() {
- public Tuple2, Object> call(Rating r){
- return new Tuple2, Object>(
- new Tuple2(r.user(), r.product()), r.rating());
- }
- }
- ));
- JavaRDD> ratesAndPreds =
- JavaPairRDD.fromJavaRDD(ratings.map(
- new Function, Object>>() {
- public Tuple2, Object> call(Rating r){
- return new Tuple2, Object>(
- new Tuple2(r.user(), r.product()), r.rating());
- }
- }
- )).join(predictions).values();
-
- // Create regression metrics object
- RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
-
- // Root mean squared error
- System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
-
- // R-squared
- System.out.format("R-squared = %f\n", regressionMetrics.r2());
- }
-}
-
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API.
-{% highlight python %}
-from pyspark.mllib.recommendation import ALS, Rating
-from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
-
-# Read in the ratings data
-lines = sc.textFile("data/mllib/sample_movielens_data.txt")
-
-def parseLine(line):
- fields = line.split("::")
- return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
-ratings = lines.map(lambda r: parseLine(r))
-
-# Train a model on to predict user-product ratings
-model = ALS.train(ratings, 10, 10, 0.01)
-
-# Get predicted ratings on all existing user-product pairs
-testData = ratings.map(lambda p: (p.user, p.product))
-predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
-
-ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
-scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
-
-# Instantiate regression metrics to compare predicted and actual ratings
-metrics = RegressionMetrics(scoreAndLabels)
-
-# Root mean sqaured error
-print("RMSE = %s" % metrics.rootMeanSquaredError)
-
-# R-squared
-print("R-squared = %s" % metrics.r2)
-
-{% endhighlight %}
+{% include_example python/mllib/ranking_metrics_example.py %}
@@ -1350,163 +582,21 @@ and evaluate the performance of the algorithm by several regression metrics.
Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API.
-{% highlight scala %}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.regression.LinearRegressionModel
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.evaluation.RegressionMetrics
-import org.apache.spark.mllib.util.MLUtils
-
-// Load the data
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache()
-
-// Build the model
-val numIterations = 100
-val model = LinearRegressionWithSGD.train(data, numIterations)
-
-// Get predictions
-val valuesAndPreds = data.map{ point =>
- val prediction = model.predict(point.features)
- (prediction, point.label)
-}
-
-// Instantiate metrics object
-val metrics = new RegressionMetrics(valuesAndPreds)
-
-// Squared error
-println(s"MSE = ${metrics.meanSquaredError}")
-println(s"RMSE = ${metrics.rootMeanSquaredError}")
-
-// R-squared
-println(s"R-squared = ${metrics.r2}")
-
-// Mean absolute error
-println(s"MAE = ${metrics.meanAbsoluteError}")
-
-// Explained variance
-println(s"Explained variance = ${metrics.explainedVariance}")
-
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API.
-{% highlight java %}
-import scala.Tuple2;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.regression.LinearRegressionModel;
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
-import org.apache.spark.mllib.evaluation.RegressionMetrics;
-import org.apache.spark.SparkConf;
-
-public class LinearRegression {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
- JavaSparkContext sc = new JavaSparkContext(conf);
-
- // Load and parse the data
- String path = "data/mllib/sample_linear_regression_data.txt";
- JavaRDD data = sc.textFile(path);
- JavaRDD parsedData = data.map(
- new Function() {
- public LabeledPoint call(String line) {
- String[] parts = line.split(" ");
- double[] v = new double[parts.length - 1];
- for (int i = 1; i < parts.length - 1; i++)
- v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
- return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
- }
- }
- );
- parsedData.cache();
-
- // Building the model
- int numIterations = 100;
- final LinearRegressionModel model =
- LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
-
- // Evaluate model on training examples and compute training error
- JavaRDD> valuesAndPreds = parsedData.map(
- new Function>() {
- public Tuple2 call(LabeledPoint point) {
- double prediction = model.predict(point.features());
- return new Tuple2(prediction, point.label());
- }
- }
- );
-
- // Instantiate metrics object
- RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
-
- // Squared error
- System.out.format("MSE = %f\n", metrics.meanSquaredError());
- System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
-
- // R-squared
- System.out.format("R Squared = %f\n", metrics.r2());
-
- // Mean absolute error
- System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
-
- // Explained variance
- System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
-
- // Save and load model
- model.save(sc.sc(), "myModelPath");
- LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
- }
-}
-
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API.
-{% highlight python %}
-from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD
-from pyspark.mllib.evaluation import RegressionMetrics
-from pyspark.mllib.linalg import DenseVector
-
-# Load and parse the data
-def parsePoint(line):
- values = line.split()
- return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]]))
-
-data = sc.textFile("data/mllib/sample_linear_regression_data.txt")
-parsedData = data.map(parsePoint)
-
-# Build the model
-model = LinearRegressionWithSGD.train(parsedData)
-
-# Get predictions
-valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label))
-
-# Instantiate metrics object
-metrics = RegressionMetrics(valuesAndPreds)
-
-# Squared Error
-print("MSE = %s" % metrics.meanSquaredError)
-print("RMSE = %s" % metrics.rootMeanSquaredError)
-
-# R-squared
-print("R-squared = %s" % metrics.r2)
-
-# Mean absolute error
-print("MAE = %s" % metrics.meanAbsoluteError)
-
-# Explained variance
-print("Explained variance = %s" % metrics.explainedVariance)
-
-{% endhighlight %}
+{% include_example python/mllib/regression_metrics_example.py %}
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index ec5a44d79212..a197d0e37302 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -161,21 +161,15 @@ Note that jars or python files that are passed to spark-submit should be URIs re
# Mesos Run Modes
-Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained".
+Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained".
-In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows
-multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
-where each application gets more or fewer machines as it ramps up and down, but it comes with an
-additional overhead in launching each task. This mode may be inappropriate for low-latency
-requirements like interactive queries or serving web requests.
-
-The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos
+The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos
machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup
overhead, but at the cost of reserving the Mesos resources for the complete duration of the
application.
-To run in coarse-grained mode, set the `spark.mesos.coarse` property in your
-[SparkConf](configuration.html#spark-properties):
+Coarse-grained is the default mode. You can also set the `spark.mesos.coarse` property to true
+to turn it on explicitly in [SparkConf](configuration.html#spark-properties):
{% highlight scala %}
conf.set("spark.mesos.coarse", "true")
@@ -186,6 +180,19 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere
only makes sense if you run just one application at a time. You can cap the maximum number of cores
using `conf.set("spark.cores.max", "10")` (for example).
+In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows
+multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
+where each application gets more or fewer machines as it ramps up and down, but it comes with an
+additional overhead in launching each task. This mode may be inappropriate for low-latency
+requirements like interactive queries or serving web requests.
+
+To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your
+[SparkConf](configuration.html#spark-properties):
+
+{% highlight scala %}
+conf.set("spark.mesos.coarse", "false")
+{% endhighlight %}
+
You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted.
{% highlight scala %}
@@ -278,7 +285,7 @@ See the [configuration page](configuration.html) for information on Spark config
Set the name of the docker image that the Spark executors will run in. The selected
image must have Spark installed, as well as a compatible version of the Mesos library.
The installed path of Spark in the image can be specified with spark.mesos.executor.home;
- the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY.
+ the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_JAVA_LIBRARY.
diff --git a/docs/sparkr.md b/docs/sparkr.md
index 437bd4756c27..cfb9b41350f4 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -286,24 +286,37 @@ head(teenagers)
# Machine Learning
-SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR.
+SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'.
+
+The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html).
+
+* For a gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (These values are only available when the model is fitted by the normal solver.)
+* For a binomial GLM model, it returns a list with a 'coefficients' component, which gives the estimated coefficients.
+
+The examples below show how to build gaussian and binomial GLM models using SparkR.
+
+## Gaussian GLM model
{% highlight r %}
# Create the DataFrame
df <- createDataFrame(sqlContext, iris)
-# Fit a linear model over the dataset.
+# Fit a gaussian GLM model over the dataset.
model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian")
-# Model coefficients are returned in a similar format to R's native glm().
+# The model summary is returned in a similar format to R's native glm().
summary(model)
+##$devianceResiduals
+## Min Max
+## -1.307112 1.412532
+##
##$coefficients
-## Estimate
-##(Intercept) 2.2513930
-##Sepal_Width 0.8035609
-##Species_versicolor 1.4587432
-##Species_virginica 1.9468169
+## Estimate Std. Error t value Pr(>|t|)
+##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09
+##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12
+##Species_versicolor 1.458743 0.1121079 13.01195 0
+##Species_virginica 1.946817 0.100015 19.46525 0
# Make predictions based on the model.
predictions <- predict(model, newData = df)
@@ -317,3 +330,59 @@ head(select(predictions, "Sepal_Length", "prediction"))
##6 5.4 5.385281
{% endhighlight %}
+
+## Binomial GLM model
+
+
+{% highlight r %}
+# Create the DataFrame
+df <- createDataFrame(sqlContext, iris)
+training <- filter(df, df$Species != "setosa")
+
+# Fit a binomial GLM model over the dataset.
+model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")
+
+# Model coefficients are returned in a similar format to R's native glm().
+summary(model)
+##$coefficients
+## Estimate
+##(Intercept) -13.046005
+##Sepal_Length 1.902373
+##Sepal_Width 0.404655
+{% endhighlight %}
+
+
+# R Function Name Conflicts
+
+When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a
+function is masking another function.
+
+The following functions are masked by the SparkR package:
+
+
+<table class="table">
+<tr><th>Masked function</th><th>How to Access</th></tr>
+<tr>
+  <td><code>cov</code> in <code>package:stats</code></td>
+  <td><code><pre>stats::cov(x, y = NULL, use = "everything",
+          method = c("pearson", "kendall", "spearman"))</pre></code></td>
+</tr>
+</table>
+
+You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html)
+
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 6e02d6564b00..e347754055e7 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -2051,6 +2051,20 @@ options.
# Migration Guide
+## Upgrading From Spark SQL 1.5 to 1.6
+
+ - From Spark 1.6, by default, the Thrift server runs in multi-session mode, which means each JDBC/ODBC
+   connection owns a copy of its own SQL configuration and temporary function registry. Cached
+ tables are still shared though. If you prefer to run the Thrift server in the old single-session
+ mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add
+ this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`:
+
+ {% highlight bash %}
+ ./sbin/start-thriftserver.sh \
+ --conf spark.sql.hive.thriftServer.singleSession=true \
+ ...
+ {% endhighlight %}
+
## Upgrading From Spark SQL 1.4 to 1.5
- Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index e9a27f446a89..96b36b7a7320 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -2001,8 +2001,7 @@ If the number of tasks launched per second is high (say, 50 or more per second),
of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second
latencies. The overhead can be reduced by the following changes:
-* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task
- sizes, and therefore reduce the time taken to send them to the slaves.
+* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. This is controlled by the ```spark.closure.serializer``` property. However, at this time, Kryo serialization cannot be enabled for closure serialization. This may be resolved in a future release.
* **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to
better task launch times than the fine-grained Mesos mode. Please refer to the
diff --git a/docs/tuning.md b/docs/tuning.md
index 879340a01544..a8fe7a453279 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -88,9 +88,39 @@ than the "raw" data inside their fields. This is due to several reasons:
but also pointers (typically 8 bytes each) to the next object in the list.
* Collections of primitive types often store them as "boxed" objects such as `java.lang.Integer`.
-This section will discuss how to determine the memory usage of your objects, and how to improve
-it -- either by changing your data structures, or by storing data in a serialized format.
-We will then cover tuning Spark's cache size and the Java garbage collector.
+This section will start with an overview of memory management in Spark, then discuss specific
+strategies the user can take to make more efficient use of memory in his/her application. In
+particular, we will describe how to determine the memory usage of your objects, and how to
+improve it -- either by changing your data structures, or by storing data in a serialized
+format. We will then cover tuning Spark's cache size and the Java garbage collector.
+
+## Memory Management Overview
+
+Memory usage in Spark largely falls under one of two categories: execution and storage.
+Execution memory refers to that used for computation in shuffles, joins, sorts and aggregations,
+while storage memory refers to that used for caching and propagating internal data across the
+cluster. In Spark, execution and storage share a unified region (M). When no execution memory is
+used, storage can acquire all the available memory and vice versa. Execution may evict storage
+if necessary, but only until total storage memory usage falls under a certain threshold (R).
+In other words, `R` describes a subregion within `M` where cached blocks are never evicted.
+Storage may not evict execution due to complexities in implementation.
+
+This design ensures several desirable properties. First, applications that do not use caching
+can use the entire space for execution, obviating unnecessary disk spills. Second, applications
+that do use caching can reserve a minimum storage space (R) where their data blocks are immune
+to being evicted. Lastly, this approach provides reasonable out-of-the-box performance for a
+variety of workloads without requiring user expertise of how memory is divided internally.
+
+Although there are two relevant configurations, the typical user should not need to adjust them
+as the default values are applicable to most workloads:
+
+* `spark.memory.fraction` expresses the size of `M` as a fraction of the total JVM heap space
+(default 0.75). The rest of the space (25%) is reserved for user data structures, internal
+metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually
+large records.
+* `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5).
+`R` is the storage space within `M` where cached blocks are immune to being evicted by execution (see the sketch below).
+
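To make these two settings concrete, here is a minimal, hypothetical sketch of overriding them on a `SparkConf`. The property names are the ones described above; the application name and the values shown are purely illustrative, and most workloads should simply keep the defaults.

{% highlight scala %}
import org.apache.spark.SparkConf

// Hypothetical sketch: setting the unified-memory knobs explicitly.
val conf = new SparkConf()
  .setAppName("MemoryTuningSketch")            // illustrative name
  .set("spark.memory.fraction", "0.75")        // size of M as a fraction of the JVM heap
  .set("spark.memory.storageFraction", "0.5")  // size of R as a fraction of M
{% endhighlight %}

Lowering `spark.memory.storageFraction` lets execution evict more of the cache under memory pressure, while raising it protects more cached blocks at the cost of execution headroom.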
## Determining Memory Consumption
@@ -151,18 +181,6 @@ time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+
each time a garbage collection occurs. Note that these logs will be on your cluster's worker nodes (in the `stdout` files in
their work directories), *not* on your driver program.
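For reference, a minimal sketch of how such GC-logging flags can be passed to the executors, assuming the standard `spark.executor.extraJavaOptions` property; the exact flag set shown is only an example.

{% highlight scala %}
import org.apache.spark.SparkConf

// Sketch: enable verbose GC logging on the executors so the GC stats described
// above appear in each worker's stdout file.
val conf = new SparkConf()
  .set("spark.executor.extraJavaOptions",
    "-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps")
{% endhighlight %}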
-**Cache Size Tuning**
-
-One important configuration parameter for GC is the amount of memory that should be used for caching RDDs.
-By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to
-cache RDDs. This means that 40% of memory is available for any objects created during task execution.
-
-In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of
-memory, lowering this value will help reduce the memory consumption. To change this to, say, 50%, you can call
-`conf.set("spark.storage.memoryFraction", "0.5")` on your SparkConf. Combined with the use of serialized caching,
-using a smaller cache should be sufficient to mitigate most of the garbage collection problems.
-In case you are interested in further tuning the Java GC, continue reading below.
-
**Advanced GC Tuning**
To further tune garbage collection, we first need to understand some basic information about memory management in the JVM:
@@ -183,9 +201,9 @@ temporary objects created during task execution. Some steps which may be useful
* Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times
before a task completes, it means that there isn't enough memory available for executing tasks.
-* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching.
- This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow
- down task execution!
+* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of
+ memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer
+ objects than to slow down task execution!
* If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You
can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
new file mode 100644
index 000000000000..848fe6566c1e
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+// $example on$
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.classification.GBTClassificationModel;
+import org.apache.spark.ml.classification.GBTClassifier;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
+import org.apache.spark.ml.feature.*;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+// $example off$
+
+public class JavaGradientBoostedTreeClassifierExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeClassifierExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+ // Load and parse the data file, converting it to a DataFrame.
+ DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+
+ // Index labels, adding metadata to the label column.
+ // Fit on whole dataset to include all labels in index.
+ StringIndexerModel labelIndexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("indexedLabel")
+ .fit(data);
+ // Automatically identify categorical features, and index them.
+ // Set maxCategories so features with > 4 distinct values are treated as continuous.
+ VectorIndexerModel featureIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(4)
+ .fit(data);
+
+ // Split the data into training and test sets (30% held out for testing)
+ DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
+ DataFrame trainingData = splits[0];
+ DataFrame testData = splits[1];
+
+ // Train a GBT model.
+ GBTClassifier gbt = new GBTClassifier()
+ .setLabelCol("indexedLabel")
+ .setFeaturesCol("indexedFeatures")
+ .setMaxIter(10);
+
+ // Convert indexed labels back to original labels.
+ IndexToString labelConverter = new IndexToString()
+ .setInputCol("prediction")
+ .setOutputCol("predictedLabel")
+ .setLabels(labelIndexer.labels());
+
+ // Chain indexers and GBT in a Pipeline
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});
+
+ // Train model. This also runs the indexers.
+ PipelineModel model = pipeline.fit(trainingData);
+
+ // Make predictions.
+ DataFrame predictions = model.transform(testData);
+
+ // Select example rows to display.
+ predictions.select("predictedLabel", "label", "features").show(5);
+
+ // Select (prediction, true label) and compute test error
+ MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+ .setLabelCol("indexedLabel")
+ .setPredictionCol("prediction")
+ .setMetricName("precision");
+ double accuracy = evaluator.evaluate(predictions);
+ System.out.println("Test Error = " + (1.0 - accuracy));
+
+ GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]);
+ System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
+ // $example off$
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java
new file mode 100644
index 000000000000..1f67b0842db0
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+// $example on$
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.feature.VectorIndexer;
+import org.apache.spark.ml.feature.VectorIndexerModel;
+import org.apache.spark.ml.regression.GBTRegressionModel;
+import org.apache.spark.ml.regression.GBTRegressor;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+// $example off$
+
+public class JavaGradientBoostedTreeRegressorExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeRegressorExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // $example on$
+ // Load and parse the data file, converting it to a DataFrame.
+ DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+
+ // Automatically identify categorical features, and index them.
+ // Set maxCategories so features with > 4 distinct values are treated as continuous.
+ VectorIndexerModel featureIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(4)
+ .fit(data);
+
+ // Split the data into training and test sets (30% held out for testing)
+ DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
+ DataFrame trainingData = splits[0];
+ DataFrame testData = splits[1];
+
+ // Train a GBT model.
+ GBTRegressor gbt = new GBTRegressor()
+ .setLabelCol("label")
+ .setFeaturesCol("indexedFeatures")
+ .setMaxIter(10);
+
+ // Chain indexer and GBT in a Pipeline
+ Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, gbt});
+
+ // Train model. This also runs the indexer.
+ PipelineModel model = pipeline.fit(trainingData);
+
+ // Make predictions.
+ DataFrame predictions = model.transform(testData);
+
+ // Select example rows to display.
+ predictions.select("prediction", "label", "features").show(5);
+
+ // Select (prediction, true label) and compute test error
+ RegressionEvaluator evaluator = new RegressionEvaluator()
+ .setLabelCol("label")
+ .setPredictionCol("prediction")
+ .setMetricName("rmse");
+ double rmse = evaluator.evaluate(predictions);
+ System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
+
+ GBTRegressionModel gbtModel = (GBTRegressionModel)(model.stages()[1]);
+ System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString());
+ // $example off$
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
index be2bf0c7b465..47665ff2b1f3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -41,7 +41,7 @@
* An example demonstrating a k-means clustering.
* Run with
*
+ }
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
index 731a369fc92c..7f9e2c973497 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
@@ -67,7 +67,7 @@ private[streaming] object WriteAheadLogUtils extends Logging {
}
def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = {
- isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = false)
+ isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = true)
}
/**
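Since this change makes batched writes the default for the driver write-ahead log, applications that depend on the previous behaviour can opt out through configuration. A minimal sketch, using the `spark.streaming.driver.writeAheadLog.allowBatching` property exercised in the test changes below:

{% highlight scala %}
import org.apache.spark.SparkConf

// Sketch only: restore unbatched driver WAL writes (the pre-change default).
val conf = new SparkConf()
  .set("spark.streaming.driver.writeAheadLog.allowBatching", "false")
{% endhighlight %}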
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index c5217149224e..609bb4413b6b 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -37,7 +37,9 @@
import com.google.common.io.Files;
import com.google.common.collect.Sets;
+import org.apache.spark.Accumulator;
import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -45,7 +47,6 @@
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.api.java.*;
import org.apache.spark.util.Utils;
-import org.apache.spark.SparkConf;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -768,6 +769,44 @@ public Iterable call(String x) {
assertOrderInvariantEquals(expected, result);
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testForeachRDD() {
+    final Accumulator<Integer> accumRdd = ssc.sc().accumulator(0);
+    final Accumulator<Integer> accumEle = ssc.sc().accumulator(0);
+    List<List<Integer>> inputData = Arrays.asList(
+        Arrays.asList(1,1,1),
+        Arrays.asList(1,1,1));
+
+    JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+    JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output
+
+    stream.foreachRDD(new VoidFunction<JavaRDD<Integer>>() {
+      @Override
+      public void call(JavaRDD<Integer> rdd) {
+        accumRdd.add(1);
+        rdd.foreach(new VoidFunction<Integer>() {
+          @Override
+          public void call(Integer i) {
+            accumEle.add(1);
+          }
+        });
+      }
+    });
+
+    // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java
+    stream.foreachRDD(new VoidFunction2<JavaRDD<Integer>, Time>() {
+      @Override
+      public void call(JavaRDD<Integer> rdd, Time time) {
+      }
+    });
+
+ JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(2, accumRdd.value().intValue());
+ Assert.assertEquals(6, accumEle.value().intValue());
+ }
+
@SuppressWarnings("unchecked")
@Test
public void testPairFlatMap() {
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java
index ec2bffd6a5b9..7a8ef9d14784 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java
@@ -23,6 +23,7 @@
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.junit.Assert.*;
+import com.google.common.io.Closeables;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -121,14 +122,19 @@ public void onStop() {
private void receive() {
try {
- Socket socket = new Socket(host, port);
- BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
- String userInput;
- while ((userInput = in.readLine()) != null) {
- store(userInput);
+ Socket socket = null;
+ BufferedReader in = null;
+ try {
+ socket = new Socket(host, port);
+ in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
+ String userInput;
+ while ((userInput = in.readLine()) != null) {
+ store(userInput);
+ }
+ } finally {
+ Closeables.close(in, /* swallowIOException = */ true);
+ Closeables.close(socket, /* swallowIOException = */ true);
}
- in.close();
- socket.close();
} catch(ConnectException ce) {
ce.printStackTrace();
restart("Could not connect", ce);
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
index 175b8a496b4e..09b5f8ed0327 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
@@ -108,6 +108,7 @@ public void close() {
public void testCustomWAL() {
SparkConf conf = new SparkConf();
conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName());
+ conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false");
WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null);
String data1 = "data1";
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 84f5294aa39c..b1cbc7163bee 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.streaming
import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
-import org.apache.spark.TestUtils
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.reflect.ClassTag
@@ -30,11 +29,13 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
+import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
+import org.apache.spark.TestUtils
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
-import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver}
+import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
/**
@@ -611,6 +612,28 @@ class CheckpointSuite extends TestSuiteBase {
assert(ois.readObject().asInstanceOf[Class[_]].getName == "[LtestClz;")
}
+ test("SPARK-11267: the race condition of two checkpoints in a batch") {
+ val jobGenerator = mock(classOf[JobGenerator])
+ val checkpointDir = Utils.createTempDir().toString
+ val checkpointWriter =
+ new CheckpointWriter(jobGenerator, conf, checkpointDir, new Configuration())
+ val bytes1 = Array.fill[Byte](10)(1)
+ new checkpointWriter.CheckpointWriteHandler(
+ Time(2000), bytes1, clearCheckpointDataLater = false).run()
+ val bytes2 = Array.fill[Byte](10)(2)
+ new checkpointWriter.CheckpointWriteHandler(
+ Time(1000), bytes2, clearCheckpointDataLater = true).run()
+ val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir).reverse.map { path =>
+ new File(path.toUri)
+ }
+ assert(checkpointFiles.size === 2)
+ // Although bytes2 was written with an old time, it contains the latest status, so we should
+    // try to read from it first.
+ assert(Files.toByteArray(checkpointFiles(0)) === bytes2)
+ assert(Files.toByteArray(checkpointFiles(1)) === bytes1)
+ checkpointWriter.stop()
+ }
+
/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index 7db17abb7947..081f5a1c93e6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -330,8 +330,13 @@ class ReceivedBlockTrackerSuite
: Seq[ReceivedBlockTrackerLogEvent] = {
logFiles.flatMap {
file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq
- }.map { byteBuffer =>
- Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array)
+ }.flatMap { byteBuffer =>
+ val validBuffer = if (WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) {
+ Utils.deserialize[Array[Array[Byte]]](byteBuffer.array()).map(ByteBuffer.wrap)
+ } else {
+ Array(byteBuffer)
+ }
+ validBuffer.map(b => Utils.deserialize[ReceivedBlockTrackerLogEvent](b.array()))
}.toList
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 5dc0472c7770..df4575ab25aa 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
+import org.apache.spark.SparkException
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
@@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
}
}
+ test("don't call ssc.stop in listener") {
+ ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000))
+ val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
+ inputStream.foreachRDD(_.count)
+
+ startStreamingContextAndCallStop(ssc)
+ }
+
test("onBatchCompleted with successful batch") {
ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
@@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
assert(failureReasons(1).contains("This is another failed job"))
}
+ private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = {
+ val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc)
+ _ssc.addStreamingListener(contextStoppingCollector)
+ val batchCounter = new BatchCounter(_ssc)
+ _ssc.start()
+ // Make sure running at least one batch
+ batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)
+ _ssc.stop()
+ assert(contextStoppingCollector.sparkExSeen)
+ }
+
private def startStreamingContextAndCollectFailureReasons(
_ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = {
val failureReasonsCollector = new FailureReasonsCollector()
@@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener {
}
}
}
+/**
+ * A StreamingListener that calls StreamingContext.stop().
+ */
+class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener {
+ @volatile var sparkExSeen = false
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ try {
+ ssc.stop()
+ } catch {
+ case se: SparkException =>
+ sparkExSeen = true
+ }
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
index e3072b444284..58aef74c0040 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
@@ -22,9 +22,10 @@ import java.io.File
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.reflect.ClassTag
+import org.scalatest.PrivateMethodTester._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
@@ -57,6 +58,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
sc = new SparkContext(conf)
}
+ override def afterAll(): Unit = {
+ if (sc != null) {
+ sc.stop()
+ }
+ }
+
test("state - get, exists, update, remove, ") {
var state: StateImpl[Int] = null
@@ -436,6 +443,41 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
}
+ test("trackStateByKey - checkpoint durations") {
+ val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream)
+
+ def testCheckpointDuration(
+ batchDuration: Duration,
+ expectedCheckpointDuration: Duration,
+ explicitCheckpointDuration: Option[Duration] = None
+ ): Unit = {
+ try {
+ ssc = new StreamingContext(sc, batchDuration)
+ val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
+ val dummyFunc = (value: Option[Int], state: State[Int]) => 0
+ val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
+ val internalTrackStateStream = trackStateStream invokePrivate privateMethod()
+
+ explicitCheckpointDuration.foreach { d =>
+ trackStateStream.checkpoint(d)
+ }
+ trackStateStream.register()
+ ssc.start() // should initialize all the checkpoint durations
+ assert(trackStateStream.checkpointDuration === null)
+ assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
+ } finally {
+ StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+ }
+ }
+
+ testCheckpointDuration(Milliseconds(100), Seconds(1))
+ testCheckpointDuration(Seconds(1), Seconds(10))
+ testCheckpointDuration(Seconds(10), Seconds(100))
+
+ testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2)))
+ testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2)))
+ testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
+ }
private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
input: Seq[Seq[K]],
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index 19ef5a14f8ab..0feb3af1abb0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -17,31 +17,40 @@
package org.apache.spark.streaming.rdd
+import java.io.File
+
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
-import org.apache.spark.streaming.{Time, State}
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.streaming.{State, Time}
+import org.apache.spark.util.Utils
-class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
+class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
private var sc: SparkContext = null
+ private var checkpointDir: File = _
override def beforeAll(): Unit = {
sc = new SparkContext(
new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
+ checkpointDir = Utils.createTempDir()
+ sc.setCheckpointDir(checkpointDir.toString)
}
override def afterAll(): Unit = {
if (sc != null) {
sc.stop()
}
+ Utils.deleteRecursively(checkpointDir)
}
+ override def sparkContext: SparkContext = sc
+
test("creation from pair RDD") {
val data = Seq((1, "1"), (2, "2"), (3, "3"))
val partitioner = new HashPartitioner(10)
@@ -278,6 +287,51 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
rdd7, Seq(("k3", 2)), Set())
}
+ test("checkpointing") {
+ /**
+ * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
+ * the data RDD and the parent TrackStateRDD.
+ */
+ def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
+ : Set[(List[(Int, Int, Long)], List[Int])] = {
+ rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
+ .collect.toSet
+ }
+
+ /** Generate TrackStateRDD with data RDD having a long lineage */
+ def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
+ : TrackStateRDD[Int, Int, Int, Int] = {
+ TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
+ }
+
+ testRDD(
+ makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
+ testRDDPartitions(
+ makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
+
+ /** Generate TrackStateRDD with parent state RDD having a long lineage */
+ def makeStateRDDWithLongLineageParenttateRDD(
+ longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
+
+ // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
+ val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
+
+    // Create a new TrackStateRDD, with the long-lineage TrackStateRDD as the parent
+ new TrackStateRDD[Int, Int, Int, Int](
+ stateRDDWithLongLineage,
+ stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
+ (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
+ Time(10),
+ None
+ )
+ }
+
+ testRDD(
+ makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
+ testRDDPartitions(
+ makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
+ }
+
/** Assert whether the `trackStateByKey` operation generates expected results */
private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
testStateRDD: TrackStateRDD[K, V, S, T],
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
index 2f11b255f110..92ad9fe52b77 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.receiver
import scala.collection.mutable
+import scala.language.reflectiveCalls
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers._
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 4273fd7dda8b..eaa88ea3cd38 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -20,7 +20,7 @@ import java.io._
import java.nio.ByteBuffer
import java.util.{Iterator => JIterator}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.{TimeUnit, CountDownLatch, ThreadPoolExecutor}
+import java.util.concurrent.{RejectedExecutionException, TimeUnit, CountDownLatch, ThreadPoolExecutor}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -30,6 +30,7 @@ import scala.language.{implicitConversions, postfixOps}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{eq => meq}
import org.mockito.Matchers._
import org.mockito.Mockito._
@@ -190,6 +191,28 @@ abstract class CommonWriteAheadLogTests(
}
assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment")
}
+
+ test(testPrefix + "parallel recovery not enabled if closeFileAfterWrite = false") {
+ // write some data
+ val writtenData = (1 to 10).map { i =>
+ val data = generateRandomData()
+ val file = testDir + s"/log-$i-$i"
+ writeDataManually(data, file, allowBatching)
+ data
+ }.flatten
+
+ val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching)
+ // create iterator but don't materialize it
+ val readData = wal.readAll().asScala.map(byteBufferToString)
+ wal.close()
+ if (closeFileAfterWrite) {
+      // the threadpool is shut down by the wal.close call above, therefore we shouldn't be able
+ // to materialize the iterator with parallel recovery
+ intercept[RejectedExecutionException](readData.toArray)
+ } else {
+ assert(readData.toSeq === writtenData)
+ }
+ }
}
class FileBasedWriteAheadLogSuite
@@ -485,15 +508,18 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
}
blockingWal.allowWrite()
- val buffer1 = wrapArrayArrayByte(Array(event1))
- val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5))
+ val buffer = wrapArrayArrayByte(Array(event1))
+ val queuedEvents = Set(event2, event3, event4, event5)
eventually(timeout(1 second)) {
assert(batchedWal.invokePrivate(queueLength()) === 0)
- verify(wal, times(1)).write(meq(buffer1), meq(3L))
+ verify(wal, times(1)).write(meq(buffer), meq(3L))
// the file name should be the timestamp of the last record, as events should be naturally
// in order of timestamp, and we need the last element.
- verify(wal, times(1)).write(meq(buffer2), meq(10L))
+ val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+ verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L))
+ val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString)
+ assert(records.toSet === queuedEvents)
}
}
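
Note on the assertion change above: the aggregated buffer is no longer compared byte-for-byte; instead it is captured with Mockito's ArgumentCaptor and the de-aggregated records are compared as a set, since the order in which queued events end up in a batch is not guaranteed. For reference, a minimal, self-contained sketch of that capture-then-inspect pattern; the Writer trait and the comma-separated record format are hypothetical stand-ins, not the suite's actual helpers:

    import org.mockito.ArgumentCaptor
    import org.mockito.Matchers.{eq => meq}
    import org.mockito.Mockito.{mock, times, verify}

    object CaptorSketch {
      // Hypothetical interface standing in for the wrapped write-ahead log.
      trait Writer { def write(record: String, time: Long): Unit }

      def main(args: Array[String]): Unit = {
        val writer = mock(classOf[Writer])
        writer.write("a,b,c", 10L)

        // Capture whatever record was written instead of asserting an exact, order-sensitive
        // value, then compare its contents as an unordered set.
        val captor = ArgumentCaptor.forClass(classOf[String])
        verify(writer, times(1)).write(captor.capture(), meq(10L))
        assert(captor.getValue.split(",").toSet == Set("a", "b", "c"))
      }
    }
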
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala
index 9152728191ea..bfc5b0cf60fb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala
@@ -56,19 +56,19 @@ class WriteAheadLogUtilsSuite extends SparkFunSuite {
test("log selection and creation") {
val emptyConf = new SparkConf() // no log configuration
- assertDriverLogClass[FileBasedWriteAheadLog](emptyConf)
+ assertDriverLogClass[FileBasedWriteAheadLog](emptyConf, isBatched = true)
assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf)
// Verify setting driver WAL class
val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class",
classOf[MockWriteAheadLog0].getName())
- assertDriverLogClass[MockWriteAheadLog0](driverWALConf)
+ assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true)
assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf)
// Verify setting receiver WAL class
val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
classOf[MockWriteAheadLog0].getName())
- assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf)
+ assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true)
assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf)
// Verify setting receiver WAL class with 1-arg constructor
@@ -104,6 +104,19 @@ class WriteAheadLogUtilsSuite extends SparkFunSuite {
assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true)
assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf)
}
+
+ test("batching is enabled by default in WriteAheadLog") {
+ val conf = new SparkConf()
+ assert(WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true))
+ // batching is not valid for receiver WALs
+ assert(!WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = false))
+ }
+
+ test("closeFileAfterWrite is disabled by default in WriteAheadLog") {
+ val conf = new SparkConf()
+ assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = true))
+ assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = false))
+ }
}
object WriteAheadLogUtilsSuite {
diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
index a0524cabff2d..5155daa6d17b 100644
--- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
@@ -72,7 +72,9 @@ object GenerateMIMAIgnore {
val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader))
val moduleSymbol = mirror.staticModule(className)
val directlyPrivateSpark =
- isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol)
+ isPackagePrivate(classSymbol) ||
+ isPackagePrivateModule(moduleSymbol) ||
+ classSymbol.isPrivate
val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol)
val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol)
/* Inner classes defined within a private[spark] class or object are effectively
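
The added `classSymbol.isPrivate` check makes GenerateMIMAIgnore exclude plain private classes in addition to package-private ones. A small sketch of how that reflection flag behaves; the `Hidden` and `Visible` classes below are invented for illustration:

    import scala.reflect.runtime.{universe => ru}

    object PrivateClassCheck {
      private class Hidden   // a plain private class, the kind now excluded
      class Visible

      def isPrivateClass(clazz: Class[_]): Boolean = {
        val mirror = ru.runtimeMirror(clazz.getClassLoader)
        mirror.classSymbol(clazz).isPrivate
      }

      def main(args: Array[String]): Unit = {
        println(isPrivateClass(classOf[Hidden]))   // true
        println(isPrivateClass(classOf[Visible]))  // false
      }
    }
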
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index caf1f77890b5..a1c1111364ee 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -36,6 +36,10 @@
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill_${scala.binary.version}</artifactId>
+    </dependency>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index b7aecb5102ba..4bd3fd777207 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -24,6 +24,11 @@
import java.util.Arrays;
import java.util.Map;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -38,9 +43,9 @@
*
* Note: This is not designed for general use cases, should not be used outside SQL.
*/
-public final class UTF8String implements Comparable<UTF8String>, Externalizable {
+public final class UTF8String implements Comparable<UTF8String>, Externalizable, KryoSerializable {
- // These are only updated by readExternal()
+ // These are only updated by readExternal() or read()
@Nonnull
private Object base;
private long offset;
@@ -1003,4 +1008,19 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
in.readFully((byte[]) base);
}
+ @Override
+ public void write(Kryo kryo, Output out) {
+ byte[] bytes = getBytes();
+ out.writeInt(bytes.length);
+ out.write(bytes);
+ }
+
+ @Override
+ public void read(Kryo kryo, Input in) {
+ this.offset = BYTE_ARRAY_OFFSET;
+ this.numBytes = in.readInt();
+ this.base = new byte[numBytes];
+ in.read((byte[]) base);
+ }
+
}
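
The new `write`/`read` pair serializes a UTF8String under Kryo as a length-prefixed byte array, mirroring what `writeExternal`/`readExternal` already do for Java serialization. A rough, hypothetical Scala sketch of the same KryoSerializable pattern with a round trip; the `BytesHolder` class is invented for illustration and default Kryo behavior (no-arg construction, explicit registration) is assumed:

    import java.io.ByteArrayOutputStream

    import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
    import com.esotericsoftware.kryo.io.{Input, Output}

    // Hypothetical value class serialized as a length-prefixed byte array.
    class BytesHolder(var bytes: Array[Byte]) extends KryoSerializable {
      def this() = this(Array.empty[Byte])   // Kryo instantiates via the no-arg constructor

      override def write(kryo: Kryo, out: Output): Unit = {
        out.writeInt(bytes.length)
        out.write(bytes)
      }

      override def read(kryo: Kryo, in: Input): Unit = {
        val n = in.readInt()
        bytes = in.readBytes(n)               // readBytes returns a fully populated array
      }
    }

    object BytesHolderRoundTrip {
      def main(args: Array[String]): Unit = {
        val kryo = new Kryo()
        kryo.register(classOf[BytesHolder])

        val bos = new ByteArrayOutputStream()
        val out = new Output(bos)
        kryo.writeObject(out, new BytesHolder("hello".getBytes("UTF-8")))
        out.close()

        val in = new Input(bos.toByteArray)
        val copy = kryo.readObject(in, classOf[BytesHolder])
        println(new String(copy.bytes, "UTF-8"))   // hello
      }
    }
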
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index a3f33d80184a..ba799884f568 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -258,7 +258,8 @@ private[spark] class Client(
if (executorMem > maxMem) {
throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" +
s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
- "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.")
+ "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " +
+ "'yarn.nodemanager.resource.memory-mb'.")
}
val amMem = args.amMemory + amMemoryOverhead
if (amMem > maxMem) {
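
The reworded message points at both YARN settings because the scheduler's maximum allocation is, in practice, also bounded by each NodeManager's configured memory. A hypothetical sketch of the kind of check the message accompanies; the names and signature are illustrative, not Client.scala's actual code:

    object MemoryCheckSketch {
      // Illustrative only: executor memory plus overhead must fit under YARN's per-container
      // maximum, governed by yarn.scheduler.maximum-allocation-mb and, on each node,
      // yarn.nodemanager.resource.memory-mb.
      def validateExecutorMemory(executorMemoryMb: Int, overheadMb: Int, maxAllocationMb: Int): Unit = {
        val requested = executorMemoryMb + overheadMb
        require(requested <= maxAllocationMb,
          s"Required executor memory ($executorMemoryMb+$overheadMb MB) is above the max " +
          s"threshold ($maxAllocationMb MB) of this cluster! Please check the values of " +
          "'yarn.scheduler.maximum-allocation-mb' and/or 'yarn.nodemanager.resource.memory-mb'.")
      }

      def main(args: Array[String]): Unit = {
        validateExecutorMemory(executorMemoryMb = 4096, overheadMb = 384, maxAllocationMb = 8192) // ok
        // validateExecutorMemory(8192, 819, 8192)  // would throw IllegalArgumentException
      }
    }
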
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 4d9e777cb413..73cd9031f025 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.util.RackResolver
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
@@ -96,6 +96,10 @@ private[yarn] class YarnAllocator(
// was lost.
private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]]
+ // Maintain loss reasons for already released executors. A loss reason is added when it is
+ // received from an AM-RM call, and removed once it has been queried.
+ private val releasedExecutorLossReasons = new HashMap[String, ExecutorLossReason]
+
// Keep track of which container is running which executor to remove the executors later
// Visible for testing.
private[yarn] val executorIdToContainer = new HashMap[String, Container]
@@ -202,8 +206,7 @@ private[yarn] class YarnAllocator(
*/
def killExecutor(executorId: String): Unit = synchronized {
if (executorIdToContainer.contains(executorId)) {
- val container = executorIdToContainer.remove(executorId).get
- containerIdToExecutorId.remove(container.getId)
+ val container = executorIdToContainer.get(executorId).get
internalReleaseContainer(container)
numExecutorsRunning -= 1
} else {
@@ -478,7 +481,7 @@ private[yarn] class YarnAllocator(
(true, memLimitExceededLogMessage(
completedContainer.getDiagnostics,
PMEM_EXCEEDED_PATTERN))
- case unknown =>
+ case _ =>
numExecutorsFailed += 1
(true, "Container marked as failed: " + containerId + onHostStr +
". Exit status: " + completedContainer.getExitStatus +
@@ -490,7 +493,7 @@ private[yarn] class YarnAllocator(
} else {
logInfo(containerExitReason)
}
- ExecutorExited(0, exitCausedByApp, containerExitReason)
+ ExecutorExited(exitStatus, exitCausedByApp, containerExitReason)
} else {
// If we have already released this container, then it must mean
// that the driver has explicitly requested it to be killed
@@ -514,9 +517,18 @@ private[yarn] class YarnAllocator(
containerIdToExecutorId.remove(containerId).foreach { eid =>
executorIdToContainer.remove(eid)
- pendingLossReasonRequests.remove(eid).foreach { pendingRequests =>
- // Notify application of executor loss reasons so it can decide whether it should abort
- pendingRequests.foreach(_.reply(exitReason))
+ pendingLossReasonRequests.remove(eid) match {
+ case Some(pendingRequests) =>
+ // Notify application of executor loss reasons so it can decide whether it should abort
+ pendingRequests.foreach(_.reply(exitReason))
+
+ case None =>
+ // There is no pending request for this executor's loss reason, because the completed
+ // container was processed before the loss reason was queried. This usually happens when
+ // a container is killed explicitly: the completion comes back in the same AM-RM
+ // communication, so the query RPC arrives only after the completed container has been
+ // processed. Store the reason so the later query can be answered.
+ releasedExecutorLossReasons.put(eid, exitReason)
}
if (!alreadyReleased) {
// The executor could have gone away (like no route to host, node failure, etc)
@@ -538,8 +550,14 @@ private[yarn] class YarnAllocator(
if (executorIdToContainer.contains(eid)) {
pendingLossReasonRequests
.getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context
+ } else if (releasedExecutorLossReasons.contains(eid)) {
+ // The executor was already released explicitly before its loss reason was requested, so
+ // reply directly with the stored loss reason.
+ context.reply(releasedExecutorLossReasons.remove(eid).get)
} else {
logWarning(s"Tried to get the loss reason for non-existent executor $eid")
+ context.sendFailure(
+ new SparkException(s"Fail to find loss reason for non-existent executor $eid"))
}
}
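
Taken together, the allocator changes handle both orderings of "container completed" and "loss reason queried": if the completion arrives first, the reason is parked in `releasedExecutorLossReasons`; if the query arrives first, the caller is parked in `pendingLossReasonRequests`. A minimal, hypothetical sketch of that two-map handshake, stripped of the YARN and RPC specifics:

    import scala.collection.mutable

    // Hypothetical illustration only; not the YarnAllocator API.
    class LossReasonBroker {
      private val pendingRequests = new mutable.HashMap[String, mutable.Buffer[String => Unit]]
      private val storedReasons = new mutable.HashMap[String, String]

      // Called when a container completes and its loss reason becomes known.
      def onContainerCompleted(executorId: String, reason: String): Unit =
        pendingRequests.remove(executorId) match {
          case Some(callbacks) => callbacks.foreach(_.apply(reason)) // a query is already waiting
          case None => storedReasons.put(executorId, reason)         // park it for a later query
        }

      // Called when the driver asks for an executor's loss reason.
      def queryLossReason(executorId: String)(reply: String => Unit): Unit =
        storedReasons.remove(executorId) match {
          case Some(reason) => reply(reason)                         // completion arrived first
          case None =>
            pendingRequests.getOrElseUpdate(executorId, mutable.Buffer.empty[String => Unit]) += reply
        }
    }

    object LossReasonBrokerDemo {
      def main(args: Array[String]): Unit = {
        val broker = new LossReasonBroker
        broker.onContainerCompleted("exec-1", "killed by driver")    // completion before query
        broker.queryLossReason("exec-1")(println)                    // prints: killed by driver
        broker.queryLossReason("exec-2")(println)                    // query before completion
        broker.onContainerCompleted("exec-2", "preempted")           // prints: preempted
      }
    }
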