diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0e36a30c933d0..7aba4e5abeb5e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -42,8 +42,10 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.conda.CondaEnvironment
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -336,6 +338,9 @@ class SparkContext(config: SparkConf) extends Logging {
override protected def initialValue(): Properties = new Properties()
}
+ // Retrieve the Conda Environment from CondaRunner if it has set one up for us
+ val condaEnvironment: Option[CondaEnvironment] = CondaRunner.condaEnvironment
+
/* ------------------------------------------------------------------------------------- *
| Initialization. This code initializes the context in a manner that is exception-safe. |
| All internal fields holding state are initialized here, and any error prompts the |
@@ -1851,6 +1856,28 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def listJars(): Seq[String] = addedJars.keySet.toSeq
+ private[this] def condaEnvironmentOrFail(): CondaEnvironment = {
+ condaEnvironment.getOrElse(sys.error("A conda environment was not set up."))
+ }
+
+ /**
+ * Add a set of conda packages (identified by their package match specifications)
+ * for all tasks to be executed on this SparkContext in the future.
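+ *
+ * For example: `addCondaPackages(Seq("numpy=1.11.1"))` (the package spec used in the PySpark
+ * conda tests).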
+ */
+ def addCondaPackages(packages: Seq[String]): Unit = {
+ condaEnvironmentOrFail().installPackages(packages)
+ }
+
+ def addCondaChannel(url: String): Unit = {
+ condaEnvironmentOrFail().addChannel(url)
+ }
+
+ private[spark] def buildCondaInstructions(): Option[CondaSetupInstructions] = {
+ condaEnvironment.map(_.buildSetupInstructions)
+ }
+
/**
* When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark
* may wait for some internal threads to finish. It's better to use this method to stop
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 539dbb55eeff0..cb4f751bc8fb7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -26,6 +26,7 @@ import scala.util.Properties
import com.google.common.collect.MapMaker
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.Logging
@@ -70,7 +71,10 @@ class SparkEnv (
val conf: SparkConf) extends Logging {
private[spark] var isStopped = false
- private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
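+ // Workers are pooled per (python exec, env vars, conda instructions): different conda package
+ // sets bootstrap distinct environments, so they must not share workers.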
+ case class PythonWorkerKey(pythonExec: Option[String], envVars: Map[String, String],
+ condaInstructions: Option[CondaSetupInstructions])
+ private val pythonWorkers = mutable.HashMap[PythonWorkerKey, PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
@@ -110,25 +114,29 @@ class SparkEnv (
}
private[spark]
- def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ def createPythonWorker(pythonExec: Option[String], envVars: Map[String, String],
+ condaInstructions: Option[CondaSetupInstructions]): java.net.Socket = {
synchronized {
- val key = (pythonExec, envVars)
- pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+ val key = PythonWorkerKey(pythonExec, envVars, condaInstructions)
+ pythonWorkers.getOrElseUpdate(key,
+ new PythonWorkerFactory(pythonExec, envVars, condaInstructions)).create()
}
}
private[spark]
- def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ def destroyPythonWorker(pythonExec: Option[String], envVars: Map[String, String],
+ condaInstructions: Option[CondaSetupInstructions], worker: Socket) {
synchronized {
- val key = (pythonExec, envVars)
+ val key = PythonWorkerKey(pythonExec, envVars, condaInstructions)
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
private[spark]
- def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ def releasePythonWorker(pythonExec: Option[String], envVars: Map[String, String],
+ condaInstructions: Option[CondaSetupInstructions], worker: Socket) {
synchronized {
- val key = (pythonExec, envVars)
+ val key = PythonWorkerKey(pythonExec, envVars, condaInstructions)
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala
new file mode 100644
index 0000000000000..b7cfa6b0e3659
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.conda
+
+import java.io.File
+import java.nio.file.Path
+import java.util.{Map => JMap}
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+
+/**
+ * A stateful class that describes a Conda environment and also keeps track of packages that have
+ * been added, as well as additional channels.
+ *
+ * @param rootPath The root path under which envs/ and pkgs/ are located.
+ * @param envName The name of the environment.
+ */
+final class CondaEnvironment(val manager: CondaEnvironmentManager,
+ val rootPath: Path,
+ val envName: String,
+ bootstrapPackages: Seq[String],
+ bootstrapChannels: Seq[String]) extends Logging {
+
+ import CondaEnvironment._
+
+ private[this] val packages = mutable.Buffer(bootstrapPackages: _*)
+ private[this] val channels = bootstrapChannels.toBuffer
+
+ val condaEnvDir: Path = rootPath.resolve("envs").resolve(envName)
+
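+ /**
+ * Builds the environment for a process running inside this conda env: `startEnv` plus this
+ * JVM's environment, with CONDA_PREFIX, CONDA_DEFAULT_ENV and PATH pointing at the env.
+ */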
+ def activatedEnvironment(startEnv: Map[String, String] = Map.empty): Map[String, String] = {
+ require(!startEnv.contains("PATH"), "Defining PATH in a CondaEnvironment's startEnv is " +
+ s"prohibited; found PATH=${startEnv("PATH")}")
+ import scala.collection.JavaConverters._
+ val newVars = System.getenv().asScala.toIterator ++ startEnv ++ List(
+ "CONDA_PREFIX" -> condaEnvDir.toString,
+ "CONDA_DEFAULT_ENV" -> condaEnvDir.toString,
+ "PATH" -> (condaEnvDir.resolve("bin").toString +
+ sys.env.get("PATH").map(File.pathSeparator + _).getOrElse(""))
+ )
+ newVars.toMap
+ }
+
+ def addChannel(url: String): Unit = {
+ channels += url
+ }
+
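+ /**
+ * Runs `conda install` for the given package match specifications against the registered
+ * channels only, and records the packages so executors can replicate the installation.
+ */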
+ def installPackages(packages: Seq[String]): Unit = {
+ manager.runCondaProcess(rootPath,
+ List("install", "-n", envName, "-y", "--override-channels")
+ ::: channels.iterator.flatMap(Iterator("--channel", _)).toList
+ ::: "--" :: packages.toList,
+ description = s"install dependencies in conda env $condaEnvDir"
+ )
+
+ this.packages ++= packages
+ }
+
+ /**
+ * Clears the given java environment and replaces all variables with the environment
+ * produced after calling `activate` inside this conda environment.
+ */
+ def initializeJavaEnvironment(env: JMap[String, String]): Unit = {
+ env.clear()
+ val activatedEnv = activatedEnvironment()
+ activatedEnv.foreach { case (k, v) => env.put(k, v) }
+ logDebug(s"Initialised environment from conda: $activatedEnv")
+ }
+
+ /**
+ * Builds the setup instructions to send to the executors so they can replicate the same steps.
+ */
+ def buildSetupInstructions: CondaSetupInstructions = {
+ CondaSetupInstructions(packages.toList, channels.toList)
+ }
+}
+
+object CondaEnvironment {
+ case class CondaSetupInstructions(packages: Seq[String], channels: Seq[String]) {
+ require(channels.nonEmpty)
+ require(packages.nonEmpty)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala
new file mode 100644
index 0000000000000..32d64e4a31a1f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.conda
+
+import java.nio.file.Files
+import java.nio.file.Path
+import java.nio.file.Paths
+
+import scala.collection.JavaConverters._
+import scala.sys.process.BasicIO
+import scala.sys.process.Process
+import scala.sys.process.ProcessBuilder
+import scala.sys.process.ProcessIO
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CONDA_BINARY_PATH
+import org.apache.spark.internal.config.CONDA_CHANNEL_URLS
+import org.apache.spark.internal.config.CONDA_VERBOSITY
+import org.apache.spark.util.Utils
+
+final class CondaEnvironmentManager(condaBinaryPath: String, condaChannelUrls: Seq[String],
+ verbosity: Int = 0)
+ extends Logging {
+
+ require(condaChannelUrls.nonEmpty, "Can't have an empty list of conda channel URLs")
+ require(verbosity >= 0 && verbosity <= 3, "Verbosity must be between 0 and 3 inclusively")
+
+ def create(
+ baseDir: String,
+ bootstrapPackages: Seq[String]): CondaEnvironment = {
+ require(bootstrapPackages.nonEmpty, "Expected at least one bootstrap package.")
+ val name = "conda-env"
+
+ // Link under /tmp to keep the path short in case baseDir is very long: if the baseDir path is
+ // too long, it breaks conda's 220-character limit for binary replacement.
+ // Don't even try to use java.io.tmpdir - YARN sets it to a very long path.
+ val linkedBaseDir = Utils.createTempDir("/tmp", "conda").toPath.resolve("real")
+ logInfo(s"Creating symlink $linkedBaseDir -> $baseDir")
+ Files.createSymbolicLink(linkedBaseDir, Paths.get(baseDir))
+
+ val verbosityFlags = 0.until(verbosity).map(_ => "-v").toList
+
+ // Attempt to create environment
+ runCondaProcess(
+ linkedBaseDir,
+ List("create", "-n", name, "-y", "--override-channels", "--no-default-packages")
+ ::: verbosityFlags
+ ::: condaChannelUrls.flatMap(Iterator("--channel", _)).toList
+ ::: "--" :: bootstrapPackages.toList,
+ description = "create conda env"
+ )
+
+ new CondaEnvironment(this, linkedBaseDir, name, bootstrapPackages, condaChannelUrls)
+ }
+
+ /**
+ * Create a condarc that only exposes package and env directories under the given baseRoot,
+ * in addition to the default pkgs directory inferred from condaBinaryPath.
+ *
+ * The file will be placed directly inside the given `baseRoot` dir, and link to `baseRoot/pkgs`
+ * as the first package cache.
+ *
+ * This hack is necessary because otherwise conda tries to use the home directory as its pkgs cache.
+ */
+ private[this] def generateCondarc(baseRoot: Path): Path = {
+ val condaPkgsPath = Paths.get(condaBinaryPath).getParent.getParent.resolve("pkgs")
+ val condarc = baseRoot.resolve("condarc")
+ val condarcContents =
+ s"""pkgs_dirs:
+ | - $baseRoot/pkgs
+ | - $condaPkgsPath
+ |envs_dirs:
+ | - $baseRoot/envs
+ |show_channel_urls: false
+ """.stripMargin
+ Files.write(condarc, List(condarcContents).asJava)
+ logInfo(f"Using condarc at $condarc:%n$condarcContents")
+ condarc
+ }
+
+ private[conda] def runCondaProcess(baseRoot: Path,
+ args: List[String],
+ description: String): Unit = {
+ val condarc = generateCondarc(baseRoot)
+ val fakeHomeDir = baseRoot.resolve("home")
+ // Attempt to create fake home dir
+ Files.createDirectories(fakeHomeDir)
+
+ val extraEnv = List(
+ "CONDARC" -> condarc.toString,
+ "HOME" -> fakeHomeDir.toString
+ )
+
+ val command = Process(
+ condaBinaryPath :: args,
+ None,
+ extraEnv: _*
+ )
+
+ logInfo(s"About to execute $command with environment $extraEnv")
+ runOrFail(command, description)
+ logInfo(s"Successfully executed $command with environment $extraEnv")
+ }
+
+ private[this] def runOrFail(command: ProcessBuilder, description: String): Unit = {
+ val buffer = new StringBuffer
+ val collectErrOutToBuffer = new ProcessIO(
+ BasicIO.input(false),
+ BasicIO.processFully(buffer),
+ BasicIO.processFully(buffer))
+ val exitCode = command.run(collectErrOutToBuffer).exitValue()
+ if (exitCode != 0) {
+ throw new SparkException(s"Attempt to $description exited with code: "
+ + f"$exitCode%nCommand was: $command%nOutput was:%n${buffer.toString}")
+ }
+ }
+}
+
+object CondaEnvironmentManager {
+ def isConfigured(sparkConf: SparkConf): Boolean = {
+ sparkConf.contains(CONDA_BINARY_PATH)
+ }
+
+ def fromConf(sparkConf: SparkConf): CondaEnvironmentManager = {
+ val condaBinaryPath = sparkConf.get(CONDA_BINARY_PATH).getOrElse(
+ sys.error(s"Expected config ${CONDA_BINARY_PATH.key} to be set"))
+ val condaChannelUrls = sparkConf.get(CONDA_CHANNEL_URLS)
+ require(condaChannelUrls.nonEmpty,
+ s"Must define at least one conda channel in config ${CONDA_CHANNEL_URLS.key}")
+ val verbosity = sparkConf.get(CONDA_VERBOSITY)
+ new CondaEnvironmentManager(condaBinaryPath, condaChannelUrls, verbosity)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 9481156bc93a5..df2064a310217 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -695,6 +695,15 @@ class JavaSparkContext(val sc: SparkContext)
sc.addJar(path)
}
+ /**
+ * Add a set of conda packages (identified by their package match specifications)
+ * for all tasks to be executed on this SparkContext in the future.
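+ *
+ * For example: `addCondaPackages(java.util.Arrays.asList("numpy=1.11.1"))`.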
+ */
+ def addCondaPackages(packages: java.util.List[String]): Unit = {
+ sc.addCondaPackages(packages.asScala)
+ }
+
/**
* Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
*
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 04ae97ed3ccbe..6f31d8d553b1a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -33,6 +33,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
import org.apache.spark._
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
@@ -72,7 +73,8 @@ private[spark] case class PythonFunction(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
- pythonExec: String,
+ condaSetupInstructions: Option[CondaSetupInstructions],
+ pythonExec: Option[String],
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2)
@@ -106,10 +108,14 @@ private[spark] class PythonRunner(
require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+ private[this] val localdirs =
+ SparkEnv.get.blockManager.diskBlockManager.localDirs.map(f => f.getPath).mkString(",")
+
// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer
+ private val condaInstructions = funcs.head.funcs.head.condaSetupInstructions
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator
@@ -121,12 +127,13 @@ private[spark] class PythonRunner(
val startTime = System.currentTimeMillis
val env = SparkEnv.get
- val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
- envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
+ envVars.put("SPARK_LOCAL_DIRS", localdirs) // it's also used in monitor thread
if (reuse_worker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
- val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
- // Whether is the worker released into idle pool
+ val worker: Socket = env
+ .createPythonWorker(pythonExec, envVars.asScala.toMap, condaInstructions)
+ // Whether the worker is released into the idle pool
@volatile var released = false
// Start a thread to feed the process input from our parent's iterator
@@ -205,7 +212,8 @@ private[spark] class PythonRunner(
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
- env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
+ env.releasePythonWorker(pythonExec, envVars.asScala.toMap,
+ condaInstructions, worker)
released = true
}
}
@@ -371,7 +379,7 @@ private[spark] class PythonRunner(
if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
- env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
+ env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, condaInstructions, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 6a5e6f7c5afb1..1904fb3f0e2fb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -26,10 +26,15 @@ import scala.collection.mutable
import scala.collection.JavaConverters._
import org.apache.spark._
+import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
+import org.apache.spark.api.conda.CondaEnvironmentManager
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CONDA_BOOTSTRAP_PACKAGES
import org.apache.spark.util.{RedirectThread, Utils}
-private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+private[spark] class PythonWorkerFactory(requestedPythonExec: Option[String],
+ requestedEnvVars: Map[String, String],
+ condaInstructions: Option[CondaSetupInstructions])
extends Logging {
import PythonWorkerFactory._
@@ -50,6 +55,38 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
+ private[this] val condaEnv = {
+ // Set up conda environment if there are any conda packages requested
+ condaInstructions.map { instructions =>
+ val condaPackages = instructions.packages
+
+ val env = SparkEnv.get
+ val condaEnvManager = CondaEnvironmentManager.fromConf(env.conf)
+ val envDir = {
+ // Which local dir to create it in?
+ val localDirs = env.blockManager.diskBlockManager.localDirs
+ val hash = Utils.nonNegativeHash(condaPackages)
+ val dirId = hash % localDirs.length
+ Utils.createTempDir(localDirs(dirId).getAbsolutePath, "conda").getAbsolutePath
+ }
+ condaEnvManager.create(envDir, condaPackages)
+ }
+ }
+
+ private[this] val envVars: Map[String, String] = {
+ condaEnv.map(_.activatedEnvironment(requestedEnvVars)).getOrElse(requestedEnvVars)
+ }
+
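+ // Resolve the worker's Python executable: inside a conda env it must be the env's own
+ // interpreter and any explicitly requested executable is rejected; otherwise use the
+ // requested one, falling back to plain "python".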
+ private[this] val pythonExec = {
+ condaEnv.map { conda =>
+ requestedPythonExec.foreach(exec => sys.error("It's forbidden to set PYSPARK_PYTHON " +
+ s"when using conda, but found: $exec"))
+
+ conda.condaEnvDir + "/bin/python"
+ }.orElse(requestedPythonExec)
+ .getOrElse("python")
+ }
+
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
envVars.getOrElse("PYTHONPATH", ""),
diff --git a/core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala b/core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala
new file mode 100644
index 0000000000000..cd78a11936c7a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import org.apache.spark.SparkConf
+import org.apache.spark.api.conda.CondaEnvironment
+import org.apache.spark.api.conda.CondaEnvironmentManager
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+/**
+ * A runner template used to launch applications with Conda. If conda is configured, it
+ * bootstraps a Conda environment and then delegates to [[CondaRunner.run]].
+ */
+abstract class CondaRunner extends Logging {
+ final def main(args: Array[String]): Unit = {
+ val sparkConf = new SparkConf()
+
+ if (CondaEnvironmentManager.isConfigured(sparkConf)) {
+ val condaBootstrapDeps = sparkConf.get(CONDA_BOOTSTRAP_PACKAGES)
+ val condaBaseDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), "conda").getAbsolutePath
+ val condaEnvironmentManager = CondaEnvironmentManager.fromConf(sparkConf)
+ val environment = condaEnvironmentManager.create(condaBaseDir, condaBootstrapDeps)
+
+ // Save this as a global so that SparkContext can access it later, e.g. when we shell out to
+ // a child process that bridges back into this JVM.
+ CondaRunner.condaEnvironment = Some(environment)
+
+ run(args, Some(environment))
+ } else {
+ run(args, None)
+ }
+ }
+
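+ /**
+ * Implemented by concrete runners (e.g. [[org.apache.spark.deploy.PythonRunner]]); receives the
+ * original arguments together with the conda environment that was bootstrapped, if any.
+ */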
+ def run(args: Array[String], maybeConda: Option[CondaEnvironment]): Unit
+}
+
+object CondaRunner {
+ private[spark] var condaEnvironment: Option[CondaEnvironment] = None
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index a8f732b11f6cf..1151202398f86 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -25,25 +25,51 @@ import scala.collection.JavaConverters._
import scala.util.Try
import org.apache.spark.{SparkConf, SparkUserAppException}
+import org.apache.spark.api.conda.CondaEnvironment
import org.apache.spark.api.python.PythonUtils
+import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
-import org.apache.spark.util.{RedirectThread, Utils}
+import org.apache.spark.util.RedirectThread
+import org.apache.spark.util.Utils
/**
* A main class used to launch Python applications. It executes python as a
* subprocess and then has it connect back to the JVM to access system properties, etc.
*/
-object PythonRunner {
- def main(args: Array[String]) {
+object PythonRunner extends CondaRunner with Logging {
+ private[this] case class Provenance(from: String, value: String) {
+ override def toString: String = s"Provenance(from = $from, value = $value)"
+ }
+
+ private[this] object Provenance {
+ def fromConf(sparkConf: SparkConf, conf: ConfigEntry[Option[String]]): Option[Provenance] = {
+ sparkConf.get(conf).map(Provenance(s"Spark config ${conf.key}", _))
+ }
+ def fromEnv(name: String): Option[Provenance] = {
+ sys.env.get(name).map(Provenance(s"Environment variable $name", _))
+ }
+ }
+
+ override def run(args: Array[String], maybeConda: Option[CondaEnvironment]): Unit = {
val pythonFile = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
val sparkConf = new SparkConf()
- val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
- .orElse(sparkConf.get(PYSPARK_PYTHON))
- .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
- .orElse(sys.env.get("PYSPARK_PYTHON"))
- .getOrElse("python")
+ val presetPythonExec = Provenance.fromConf(sparkConf, PYSPARK_DRIVER_PYTHON)
+ .orElse(Provenance.fromConf(sparkConf, PYSPARK_PYTHON))
+ .orElse(Provenance.fromEnv("PYSPARK_DRIVER_PYTHON"))
+ .orElse(Provenance.fromEnv("PYSPARK_PYTHON"))
+
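+ // With conda, the driver must use the bootstrapped env's interpreter and any preset Python is
+ // rejected; otherwise honour the usual driver/executor precedence, falling back to "python".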
+ val pythonExec: String = maybeConda.map { conda =>
+ presetPythonExec.foreach { exec =>
+ sys.error(
+ s"It's forbidden to set the PySpark Python executable when using conda, but found: $exec")
+ }
+ conda.condaEnvDir + "/bin/python"
+ }.orElse(presetPythonExec.map(_.value))
+ .getOrElse("python")
+
+ logInfo(s"Python binary that will be called: $pythonExec")
// Format python file paths before adding them to the PYTHONPATH
val formattedPythonFile = formatPath(pythonFile)
@@ -78,6 +104,8 @@ object PythonRunner {
// Launch Python process
val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
val env = builder.environment()
+ // If there is a CondaEnvironment set up, initialise our process' env from that
+ maybeConda.foreach(_.initializeJavaEnvironment(env))
env.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 223c921810378..0928f569e4906 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -169,6 +169,30 @@ package object config {
.stringConf
.createOptional
+ private[spark] val CONDA_BOOTSTRAP_PACKAGES = ConfigBuilder("spark.conda.bootstrapPackages")
+ .doc("The packages that will be added to the conda environment. "
+ + "Only relevant when the main class is a CondaRunner.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val CONDA_CHANNEL_URLS = ConfigBuilder("spark.conda.channelUrls")
+ .doc("The URLs of the Conda channels to use when resolving the conda packages. "
+ + "Only relevant when the main class is a CondaRunner.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val CONDA_BINARY_PATH = ConfigBuilder("spark.conda.binaryPath")
+ .doc("The location of the conda binary. Only relevant when the main class is a CondaRunner.")
+ .stringConf
+ .createOptional
+
+ private[spark] val CONDA_VERBOSITY = ConfigBuilder("spark.conda.verbosity")
+ .doc("How many times to apply -v to conda. A number between 0 and 3, with 0 being the default.")
+ .intConf
+ .createWithDefault(0)
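+
+ // An illustrative conda configuration (values taken from the PySpark conda tests):
+ //   spark.conda.binaryPath        <path to the conda binary>
+ //   spark.conda.channelUrls       https://repo.continuum.io/pkgs/free
+ //   spark.conda.bootstrapPackages python=3.5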
+
// To limit memory usage, we only track information for a fixed number of tasks
private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks")
.intConf
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 8ce30f54faace..b18e1d511260e 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -270,7 +270,7 @@ def exec_maven(mvn_args=()):
kill_zinc_on_port(zinc_port)
-def exec_sbt(sbt_args=(), exit_on_failure=True):
+def exec_sbt(sbt_args=()):
"""Will call SBT in the current directory with the list of mvn_args passed
in and returns the subprocess for any further processing"""
@@ -295,8 +295,7 @@ def exec_sbt(sbt_args=(), exit_on_failure=True):
retcode = sbt_proc.wait()
if retcode != 0:
- if exit_on_failure:
- exit_from_command_with_retcode(sbt_cmd, retcode)
+ exit_from_command_with_retcode(sbt_cmd, retcode)
return sbt_cmd, retcode
@@ -406,7 +405,7 @@ def run_scala_tests_sbt(test_modules, test_profiles):
print("[info] Running Spark tests using SBT with these arguments: ",
" ".join(profiles_and_goals))
- exec_sbt(profiles_and_goals)
+ sbt_cmd, retcode = exec_sbt(profiles_and_goals)
def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags):
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 56b8c0b95e8a4..721e901be97f1 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -49,6 +49,9 @@ object MimaExcludes {
// [SPARK-18537] Add a REST api to spark streaming
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"),
+ // CondaRunner is meant to own the final main() method and delegate to run()
+ ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.deploy.CondaRunner.main"),
+
// [SPARK-19148][SQL] do not expose the external table concept in Catalog
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"),
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 2961cda553d6a..c28a7709f56e2 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -188,7 +188,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
self._jsc.sc().register(self._javaAccumulator)
- self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+ self.pythonExec = self._jvm.scala.Option.apply(os.environ.get("PYSPARK_PYTHON"))
self.pythonVer = "%d.%d" % sys.version_info[:2]
if sys.version_info < (2, 7):
@@ -837,6 +837,20 @@ def addPyFile(self, path):
import importlib
importlib.invalidate_caches()
+ def addCondaPackages(self, *packages):
+ """
+ Add a conda package match specification for all tasks to be executed on
+ this SparkContext in the future.
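+
+ Example (package spec borrowed from the conda tests):
+ ``sc.addCondaPackages('numpy=1.11.1')``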
+ """
+ self._jsc.addCondaPackages(packages)
+
+ def addCondaChannel(self, url):
+ self._jsc.sc().addCondaChannel(url)
+
+ def _build_conda_instructions(self):
+ return self._jsc.sc().buildCondaInstructions()
+
def setCheckpointDir(self, dirName):
"""
Set the directory under which RDDs are going to be checkpointed. The
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3c783ae541a1f..66d5c783dc367 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -115,6 +115,7 @@ def killChild():
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
+ java_import(gateway.jvm, "org.apache.spark.api.conda.*")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index a5e6e2b054963..4a7b66268221a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2370,8 +2370,9 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None):
assert serializer, "serializer should not be empty"
command = (func, profiler, deserializer, serializer)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
- return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
- sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes,
+ sc._build_conda_instructions(), sc.pythonExec, sc.pythonVer,
+ broadcast_vars, sc._javaAccumulator)
class PipelinedRDD(RDD):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 376b86ea69bd4..ef1e0cc29dd6a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1859,8 +1859,9 @@ def sort_array(col, asc=True):
def _wrap_function(sc, func, returnType):
command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
- return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
- sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes,
+ sc._build_conda_instructions(), sc.pythonExec, sc.pythonVer,
+ broadcast_vars, sc._javaAccumulator)
class UserDefinedFunction(object):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c6c87a9ea5555..b7704d0d5395d 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -2057,6 +2057,35 @@ def test_user_configuration(self):
out, err = proc.communicate()
self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out))
+ def test_conda(self):
+ """Submit and test a single script file via conda"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |
+ |sc = SparkContext()
+ |sc.addCondaPackages('numpy=1.11.1')
+ |
+ |# Ensure numpy is accessible on the driver
+ |import numpy
+ |arr = [1, 2, 3]
+ |def mul2(x):
+ | # Also ensure numpy accessible from executor
+ | assert numpy.version.version == "1.11.1"
+ | return x * 2
+ |print(sc.parallelize(arr).map(mul2).collect())
+ """)
+ props = self.createTempFile("properties", """
+ |spark.conda.binaryPath {}
+ |spark.conda.channelUrls https://repo.continuum.io/pkgs/free
+ |spark.conda.bootstrapPackages python=3.5
+ """.format(os.environ["CONDA_BIN"]))
+ proc = subprocess.Popen([self.sparkSubmit,
+ "--properties-file", props,
+ script], stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 4, 6]", out.decode('utf-8'))
+
class ContextTests(unittest.TestCase):
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 9c3b18e4ec5f3..7cc3075eb766c 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -137,7 +137,8 @@ abstract class BaseYarnClusterSuite
extraClassPath: Seq[String] = Nil,
extraJars: Seq[String] = Nil,
extraConf: Map[String, String] = Map(),
- extraEnv: Map[String, String] = Map()): SparkAppHandle.State = {
+ extraEnv: Map[String, String] = Map(),
+ timeoutDuration: FiniteDuration = 2.minutes): SparkAppHandle.State = {
val deployMode = if (clientMode) "client" else "cluster"
val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf)
val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv
@@ -167,7 +168,7 @@ abstract class BaseYarnClusterSuite
val handle = launcher.startApplication()
try {
- eventually(timeout(2 minutes), interval(1 second)) {
+ eventually(timeout(timeoutDuration), interval(1 second)) {
assert(handle.getState().isFinal())
}
} finally {
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 99fb58a28934a..40dc2ec26a5fd 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -74,6 +74,36 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
| sc.stop()
""".stripMargin
+ private val TEST_CONDA_PYFILE = """
+ |import mod1, mod2
+ |import sys
+ |from operator import add
+ |
+ |from pyspark import SparkConf, SparkContext
+ |if __name__ == "__main__":
+ | if len(sys.argv) != 2:
+ | print("Usage: test.py [result file]", file=sys.stderr)
+ | exit(-1)
+ | sc = SparkContext(conf=SparkConf())
+ |
+ | sc.addCondaPackages('numpy=1.11.1')
+ | import numpy
+ |
+ | status = open(sys.argv[1],'w')
+ |
+ | numpy_multiply = lambda x: numpy.multiply(x, mod1.func() * mod2.func())
+ |
+ | rdd = sc.parallelize(range(10)).map(numpy_multiply)
+ | cnt = rdd.count()
+ | if cnt == 10:
+ | result = "success"
+ | else:
+ | result = "failure"
+ | status.write(result)
+ | status.close()
+ | sc.stop()
+ """.stripMargin
+
private val TEST_PYMODULE = """
|def func():
| return 42
@@ -139,6 +169,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testPySpark(false)
}
+ test("run Python application within Conda in yarn-client mode") {
+ testCondaPySpark(true)
+ }
+
+ test("run Python application within Conda in yarn-cluster mode") {
+ testCondaPySpark(false)
+ }
+
test("run Python application in yarn-cluster mode using " +
" spark.yarn.appMasterEnv to override local envvar") {
testPySpark(
@@ -273,6 +311,55 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
checkResult(finalState, result)
}
+ private def testCondaPySpark(
+ clientMode: Boolean,
+ extraEnv: Map[String, String] = Map()): Unit = {
+ val primaryPyFile = new File(tempDir, "test.py")
+ Files.write(TEST_CONDA_PYFILE, primaryPyFile, StandardCharsets.UTF_8)
+
+ // When running tests, let's not assume the user has built the assembly module, which also
+ // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the
+ // needed locations.
+ val sparkHome = sys.props("spark.test.home")
+ val pythonPath = Seq(
+ s"$sparkHome/python/lib/py4j-0.10.4-src.zip",
+ s"$sparkHome/python")
+ val extraEnvVars = Map(
+ "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
+ "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv
+
+ val extraConf: Map[String, String] = Map(
+ "spark.conda.binaryPath" -> sys.env("CONDA_BIN"),
+ "spark.conda.channelUrls" -> "https://repo.continuum.io/pkgs/free",
+ "spark.conda.bootstrapPackages" -> "python=3.5"
+ )
+
+ val moduleDir =
+ if (clientMode) {
+ // In client-mode, .py files added with --py-files are not visible in the driver.
+ // This is something that the launcher library would have to handle.
+ tempDir
+ } else {
+ val subdir = new File(tempDir, "pyModules")
+ subdir.mkdir()
+ subdir
+ }
+ val pyModule = new File(moduleDir, "mod1.py")
+ Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8)
+
+ val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir)
+ val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",")
+ val result = File.createTempFile("result", null, tempDir)
+
+ val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files" -> pyFiles),
+ appArgs = Seq(result.getAbsolutePath()),
+ extraEnv = extraEnvVars,
+ extraConf = extraConf,
+ timeoutDuration = 4.minutes) // give it a bit longer
+ checkResult(finalState, result)
+ }
+
private def testUseClassPathFirst(clientMode: Boolean): Unit = {
// Create a jar file that contains a different version of "test.resource".
val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 2a3d1cf0b298a..0d67796efe72b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -99,7 +99,8 @@ class DummyUDF extends PythonFunction(
command = Array[Byte](),
envVars = Map("" -> "").asJava,
pythonIncludes = ArrayBuffer("").asJava,
- pythonExec = "",
+ condaSetupInstructions = None,
+ pythonExec = None,
pythonVer = "",
broadcastVars = null,
accumulator = null)