diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e36a30c933d0..7aba4e5abeb5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -42,8 +42,10 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.conda.CondaEnvironment +import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -336,6 +338,9 @@ class SparkContext(config: SparkConf) extends Logging { override protected def initialValue(): Properties = new Properties() } + // Retrieve the Conda Environment from CondaRunner if it has set one up for us + val condaEnvironment: Option[CondaEnvironment] = CondaRunner.condaEnvironment + /* ------------------------------------------------------------------------------------- * | Initialization. This code initializes the context in a manner that is exception-safe. | | All internal fields holding state are initialized here, and any error prompts the | @@ -1851,6 +1856,28 @@ class SparkContext(config: SparkConf) extends Logging { */ def listJars(): Seq[String] = addedJars.keySet.toSeq + private[this] def condaEnvironmentOrFail(): CondaEnvironment = { + condaEnvironment.getOrElse(sys.error("A conda environment was not set up.")) + } + + /** + * Add a set of conda packages (identified by package match specification + * for all tasks to be executed on this SparkContext in the future. + */ + def addCondaPackages(packages: Seq[String]): Unit = { + condaEnvironmentOrFail().installPackages(packages) + } + + def addCondaChannel(url: String): Unit = { + condaEnvironmentOrFail().addChannel(url) + } + + private[spark] def buildCondaInstructions(): Option[CondaSetupInstructions] = { + condaEnvironment.map(_.buildSetupInstructions) + } + + /** * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark * may wait for some internal threads to finish. 
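
// Illustrative usage sketch (not part of the patch): how the new driver-side conda API on
// SparkContext is meant to be called. It assumes the application was launched through a
// CondaRunner main class so that SparkContext.condaEnvironment is populated; the channel URL and
// package spec are the placeholder values also used by this patch's tests, not required values.
import org.apache.spark.{SparkConf, SparkContext}

object CondaDriverExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("conda-example"))
    sc.addCondaChannel("https://repo.continuum.io/pkgs/free") // extra channel for later installs
    sc.addCondaPackages(Seq("numpy=1.11.1")) // installed on the driver and replayed on executors
    sc.stop()
  }
}
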
It's better to use this method to stop diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 539dbb55eeff0..cb4f751bc8fb7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -26,6 +26,7 @@ import scala.util.Properties import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging @@ -70,7 +71,10 @@ class SparkEnv ( val conf: SparkConf) extends Logging { private[spark] var isStopped = false - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() + + case class PythonWorkerKey(pythonExec: Option[String], envVars: Map[String, String], + condaInstructions: Option[CondaSetupInstructions]) + private val pythonWorkers = mutable.HashMap[PythonWorkerKey, PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). @@ -110,25 +114,29 @@ class SparkEnv ( } private[spark] - def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { + def createPythonWorker(pythonExec: Option[String], envVars: Map[String, String], + condaInstructions: Option[CondaSetupInstructions]): java.net.Socket = { synchronized { - val key = (pythonExec, envVars) - pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() + val key = PythonWorkerKey(pythonExec, envVars, condaInstructions) + pythonWorkers.getOrElseUpdate(key, + new PythonWorkerFactory(pythonExec, envVars, condaInstructions)).create() } } private[spark] - def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + def destroyPythonWorker(pythonExec: Option[String], envVars: Map[String, String], + condaInstructions: Option[CondaSetupInstructions], worker: Socket) { synchronized { - val key = (pythonExec, envVars) + val key = PythonWorkerKey(pythonExec, envVars, condaInstructions) pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } private[spark] - def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + def releasePythonWorker(pythonExec: Option[String], envVars: Map[String, String], + condaInstructions: Option[CondaSetupInstructions], worker: Socket) { synchronized { - val key = (pythonExec, envVars) + val key = PythonWorkerKey(pythonExec, envVars, condaInstructions) pythonWorkers.get(key).foreach(_.releaseWorker(worker)) } } diff --git a/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala new file mode 100644 index 0000000000000..b7cfa6b0e3659 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironment.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.conda + +import java.io.File +import java.nio.file.Path +import java.util.{Map => JMap} + +import scala.collection.mutable + +import org.apache.spark.internal.Logging + +/** + * A stateful class that describes a Conda environment and also keeps track of packages that have + * been added, as well as additional channels. + * + * @param rootPath The root path under which envs/ and pkgs/ are located. + * @param envName The name of the environment. + */ +final class CondaEnvironment(val manager: CondaEnvironmentManager, + val rootPath: Path, + val envName: String, + bootstrapPackages: Seq[String], + bootstrapChannels: Seq[String]) extends Logging { + + import CondaEnvironment._ + + private[this] val packages = mutable.Buffer(bootstrapPackages: _*) + private[this] val channels = bootstrapChannels.toBuffer + + val condaEnvDir: Path = rootPath.resolve("envs").resolve(envName) + + def activatedEnvironment(startEnv: Map[String, String] = Map.empty): Map[String, String] = { + require(!startEnv.contains("PATH"), "Defining PATH in a CondaEnvironment's startEnv is " + + s"prohibited; found PATH=${startEnv("PATH")}") + import collection.JavaConverters._ + val newVars = System.getenv().asScala.toIterator ++ startEnv ++ List( + "CONDA_PREFIX" -> condaEnvDir.toString, + "CONDA_DEFAULT_ENV" -> condaEnvDir.toString, + "PATH" -> (condaEnvDir.resolve("bin").toString + + sys.env.get("PATH").map(File.pathSeparator + _).getOrElse("")) + ) + newVars.toMap + } + + def addChannel(url: String): Unit = { + channels += url + } + + def installPackages(packages: Seq[String]): Unit = { + manager.runCondaProcess(rootPath, + List("install", "-n", envName, "-y", "--override-channels") + ::: channels.iterator.flatMap(Iterator("--channel", _)).toList + ::: "--" :: packages.toList, + description = s"install dependencies in conda env $condaEnvDir" + ) + + this.packages ++= packages + } + + /** + * Clears the given java environment and replaces all variables with the environment + * produced after calling `activate` inside this conda environment. + */ + def initializeJavaEnvironment(env: JMap[String, String]): Unit = { + env.clear() + val activatedEnv = activatedEnvironment() + activatedEnv.foreach { case (k, v) => env.put(k, v) } + logDebug(s"Initialised environment from conda: $activatedEnv") + } + + /** + * This is for sending the instructions to the executors so they can replicate the same steps. 
+ */ + def buildSetupInstructions: CondaSetupInstructions = { + CondaSetupInstructions(packages.toList, channels.toList) + } +} + +object CondaEnvironment { + case class CondaSetupInstructions(packages: Seq[String], channels: Seq[String]) { + require(channels.nonEmpty) + require(packages.nonEmpty) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala new file mode 100644 index 0000000000000..32d64e4a31a1f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/conda/CondaEnvironmentManager.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.conda + +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +import scala.collection.JavaConverters._ +import scala.sys.process.BasicIO +import scala.sys.process.Process +import scala.sys.process.ProcessBuilder +import scala.sys.process.ProcessIO + +import org.apache.spark.SparkConf +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CONDA_BINARY_PATH +import org.apache.spark.internal.config.CONDA_CHANNEL_URLS +import org.apache.spark.internal.config.CONDA_VERBOSITY +import org.apache.spark.util.Utils + +final class CondaEnvironmentManager(condaBinaryPath: String, condaChannelUrls: Seq[String], + verbosity: Int = 0) + extends Logging { + + require(condaChannelUrls.nonEmpty, "Can't have an empty list of conda channel URLs") + require(verbosity >= 0 && verbosity <= 3, "Verbosity must be between 0 and 3 inclusively") + + def create( + baseDir: String, + bootstrapPackages: Seq[String]): CondaEnvironment = { + require(bootstrapPackages.nonEmpty, "Expected at least one bootstrap package.") + val name = "conda-env" + + // must link in /tmp to reduce path length in case baseDir is very long... + // If baseDir path is too long, this breaks conda's 220 character limit for binary replacement. 
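
// Illustrative sketch (not part of the patch): buildSetupInstructions snapshots the accumulated
// packages and channels into the CondaSetupInstructions case class above so executors can replay
// the same installs; the require() calls mean a snapshot only exists once at least one package
// and one channel are present. The values below are placeholders borrowed from this patch's tests.
import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions

object CondaInstructionsSketch {
  // Equivalent to what CondaEnvironment.buildSetupInstructions would return for such a setup; the
  // instructions travel with each PythonFunction and become part of the Python worker pool key.
  val instructions = CondaSetupInstructions(
    packages = Seq("numpy=1.11.1"),
    channels = Seq("https://repo.continuum.io/pkgs/free"))
}
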
+ // Don't even try to use java.io.tmpdir - yarn sets this to a very long path + val linkedBaseDir = Utils.createTempDir("/tmp", "conda").toPath.resolve("real") + logInfo(s"Creating symlink $linkedBaseDir -> $baseDir") + Files.createSymbolicLink(linkedBaseDir, Paths.get(baseDir)) + + val verbosityFlags = 0.until(verbosity).map(_ => "-v").toList + + // Attempt to create environment + runCondaProcess( + linkedBaseDir, + List("create", "-n", name, "-y", "--override-channels", "--no-default-packages") + ::: verbosityFlags + ::: condaChannelUrls.flatMap(Iterator("--channel", _)).toList + ::: "--" :: bootstrapPackages.toList, + description = "create conda env" + ) + + new CondaEnvironment(this, linkedBaseDir, name, bootstrapPackages, condaChannelUrls) + } + + /** + * Create a condarc that only exposes package and env directories under the given baseRoot, + * on top of the from the default pkgs directory inferred from condaBinaryPath. + * + * The file will be placed directly inside the given `baseRoot` dir, and link to `baseRoot/pkgs` + * as the first package cache. + * + * This hack is necessary otherwise conda tries to use the homedir for pkgs cache. + */ + private[this] def generateCondarc(baseRoot: Path): Path = { + val condaPkgsPath = Paths.get(condaBinaryPath).getParent.getParent.resolve("pkgs") + val condarc = baseRoot.resolve("condarc") + val condarcContents = + s"""pkgs_dirs: + | - $baseRoot/pkgs + | - $condaPkgsPath + |envs_dirs: + | - $baseRoot/envs + |show_channel_urls: false + """.stripMargin + Files.write(condarc, List(condarcContents).asJava) + logInfo(f"Using condarc at $condarc:%n$condarcContents") + condarc + } + + private[conda] def runCondaProcess(baseRoot: Path, + args: List[String], + description: String): Unit = { + val condarc = generateCondarc(baseRoot) + val fakeHomeDir = baseRoot.resolve("home") + // Attempt to create fake home dir + Files.createDirectories(fakeHomeDir) + + val extraEnv = List( + "CONDARC" -> condarc.toString, + "HOME" -> fakeHomeDir.toString + ) + + val command = Process( + condaBinaryPath :: args, + None, + extraEnv: _* + ) + + logInfo(s"About to execute $command with environment $extraEnv") + runOrFail(command, description) + logInfo(s"Successfully executed $command with environment $extraEnv") + } + + private[this] def runOrFail(command: ProcessBuilder, description: String): Unit = { + val buffer = new StringBuffer + val collectErrOutToBuffer = new ProcessIO( + BasicIO.input(false), + BasicIO.processFully(buffer), + BasicIO.processFully(buffer)) + val exitCode = command.run(collectErrOutToBuffer).exitValue() + if (exitCode != 0) { + throw new SparkException(s"Attempt to $description exited with code: " + + f"$exitCode%nCommand was: $command%nOutput was:%n${buffer.toString}") + } + } +} + +object CondaEnvironmentManager { + def isConfigured(sparkConf: SparkConf): Boolean = { + sparkConf.contains(CONDA_BINARY_PATH) + } + + def fromConf(sparkConf: SparkConf): CondaEnvironmentManager = { + val condaBinaryPath = sparkConf.get(CONDA_BINARY_PATH).getOrElse( + sys.error(s"Expected config ${CONDA_BINARY_PATH.key} to be set")) + val condaChannelUrls = sparkConf.get(CONDA_CHANNEL_URLS) + require(condaChannelUrls.nonEmpty, + s"Must define at least one conda channel in config ${CONDA_CHANNEL_URLS.key}") + val verbosity = sparkConf.get(CONDA_VERBOSITY) + new CondaEnvironmentManager(condaBinaryPath, condaChannelUrls, verbosity) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala 
b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 9481156bc93a5..df2064a310217 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -695,6 +695,15 @@ class JavaSparkContext(val sc: SparkContext) sc.addJar(path) } + /** + * Add a set of conda packages (identified by package match specification + * for all tasks to be executed on this SparkContext in the future. + */ + def addCondaPackages(packages: java.util.List[String]): Unit = { + sc.addCondaPackages(packages.asScala) + } + /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 04ae97ed3ccbe..6f31d8d553b1a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} import org.apache.spark._ +import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream @@ -72,7 +73,8 @@ private[spark] case class PythonFunction( command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], - pythonExec: String, + condaSetupInstructions: Option[CondaSetupInstructions], + pythonExec: Option[String], pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: PythonAccumulatorV2) @@ -106,10 +108,14 @@ private[spark] class PythonRunner( require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + private[this] val localdirs = + SparkEnv.get.blockManager.diskBlockManager.localDirs.map(f => f.getPath).mkString(",") + // All the Python functions should have the same exec, version and envvars. private val envVars = funcs.head.funcs.head.envVars private val pythonExec = funcs.head.funcs.head.pythonExec private val pythonVer = funcs.head.funcs.head.pythonVer + private val condaInstructions = funcs.head.funcs.head.condaSetupInstructions // TODO: support accumulator in multiple UDF private val accumulator = funcs.head.funcs.head.accumulator @@ -121,12 +127,13 @@ private[spark] class PythonRunner( val startTime = System.currentTimeMillis val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + envVars.put("SPARK_LOCAL_DIRS", localdirs) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool + val worker: Socket = env + .createPythonWorker(pythonExec, envVars.asScala.toMap, condaInstructions) + // Whether the worker is released into the idle pool @volatile var released = false // Start a thread to feed the process input from our parent's iterator @@ -205,7 +212,8 @@ private[spark] class PythonRunner( // Check whether the worker is ready to be re-used. 
if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, + condaInstructions, worker) released = true } } @@ -371,7 +379,7 @@ private[spark] class PythonRunner( if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, condaInstructions, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 6a5e6f7c5afb1..1904fb3f0e2fb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -26,10 +26,15 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.spark._ +import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions +import org.apache.spark.api.conda.CondaEnvironmentManager import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CONDA_BOOTSTRAP_PACKAGES import org.apache.spark.util.{RedirectThread, Utils} -private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) +private[spark] class PythonWorkerFactory(requestedPythonExec: Option[String], + requestedEnvVars: Map[String, String], + condaInstructions: Option[CondaSetupInstructions]) extends Logging { import PythonWorkerFactory._ @@ -50,6 +55,38 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() + private[this] val condaEnv = { + // Set up conda environment if there are any conda packages requested + condaInstructions.map { instructions => + val condaPackages = instructions.packages + + val env = SparkEnv.get + val condaEnvManager = CondaEnvironmentManager.fromConf(env.conf) + val envDir = { + // Which local dir to create it in? + val localDirs = env.blockManager.diskBlockManager.localDirs + val hash = Utils.nonNegativeHash(condaPackages) + val dirId = hash % localDirs.length + Utils.createTempDir(localDirs(dirId).getAbsolutePath, "conda").getAbsolutePath + } + condaEnvManager.create(envDir, condaPackages) + } + } + + private[this] val envVars: Map[String, String] = { + condaEnv.map(_.activatedEnvironment(requestedEnvVars)).getOrElse(requestedEnvVars) + } + + private[this] val pythonExec = { + condaEnv.map { conda => + requestedPythonExec.foreach(exec => sys.error(s"It's forbidden to set the PYSPARK_PYTHON " + + s"when using conda, but found: $exec")) + + conda.condaEnvDir + "/bin/python" + }.orElse(requestedPythonExec) + .getOrElse("python") + } + val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", ""), diff --git a/core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala b/core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala new file mode 100644 index 0000000000000..cd78a11936c7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
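
// Sketch (not part of the patch) restating the executor-side resolution PythonWorkerFactory now
// performs, as a standalone function; the function name is invented, but the logic mirrors the
// condaEnv/pythonExec handling above: with conda instructions present, python always comes from
// the freshly created environment, and an explicitly requested interpreter is rejected.
object PythonExecResolutionSketch {
  def resolve(condaEnvDir: Option[String], requestedExec: Option[String]): String =
    condaEnvDir.map { dir =>
      require(requestedExec.isEmpty,
        s"It's forbidden to set PYSPARK_PYTHON when using conda, but found: $requestedExec")
      dir + "/bin/python"
    }.orElse(requestedExec)
      .getOrElse("python")
}
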
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import org.apache.spark.SparkConf +import org.apache.spark.api.conda.CondaEnvironment +import org.apache.spark.api.conda.CondaEnvironmentManager +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +/** + * A runner template used to launch applications using Conda. It bootstraps a Conda env, + * then delegates to the [[CondaRunner.run]]. + */ +abstract class CondaRunner extends Logging { + final def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf() + + if (CondaEnvironmentManager.isConfigured(sparkConf)) { + val condaBootstrapDeps = sparkConf.get(CONDA_BOOTSTRAP_PACKAGES) + val condaBaseDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), "conda").getAbsolutePath + val condaEnvironmentManager = CondaEnvironmentManager.fromConf(sparkConf) + val environment = condaEnvironmentManager.create(condaBaseDir, condaBootstrapDeps) + + // Save this as a global in order for SparkContext to be able to access it later, in case we + // are shelling out, but providing a bridge back into this JVM. + CondaRunner.condaEnvironment = Some(environment) + + run(args, Some(environment)) + } else { + run(args, None) + } + } + + def run(args: Array[String], maybeConda: Option[CondaEnvironment]): Unit +} + +object CondaRunner { + private[spark] var condaEnvironment: Option[CondaEnvironment] = None +} diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index a8f732b11f6cf..1151202398f86 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -25,25 +25,51 @@ import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.{SparkConf, SparkUserAppException} +import org.apache.spark.api.conda.CondaEnvironment import org.apache.spark.api.python.PythonUtils +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.util.{RedirectThread, Utils} +import org.apache.spark.util.RedirectThread +import org.apache.spark.util.Utils /** * A main class used to launch Python applications. It executes python as a * subprocess and then has it connect back to the JVM to access system properties, etc. 
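
// Hypothetical minimal subclass (not part of the patch), to make the CondaRunner contract above
// concrete: main() is final, bootstraps a conda environment only when spark.conda.binaryPath is
// configured, publishes it via CondaRunner.condaEnvironment for SparkContext to pick up, and then
// delegates to run(). The object name and body are invented for illustration.
import org.apache.spark.api.conda.CondaEnvironment
import org.apache.spark.deploy.CondaRunner

object MyCondaApp extends CondaRunner {
  override def run(args: Array[String], maybeConda: Option[CondaEnvironment]): Unit = {
    maybeConda match {
      case Some(env) => println(s"Conda environment created at ${env.condaEnvDir}")
      case None => println("spark.conda.binaryPath not configured; running without conda")
    }
    // ... start the actual application here, e.g. create a SparkContext and submit work ...
  }
}
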
*/ -object PythonRunner { - def main(args: Array[String]) { +object PythonRunner extends CondaRunner with Logging { + private[this] case class Provenance(from: String, value: String) { + override def toString: String = s"Provenance(from = $from, value = $value)" + } + + private[this] object Provenance { + def fromConf(sparkConf: SparkConf, conf: ConfigEntry[Option[String]]): Option[Provenance] = { + sparkConf.get(conf).map(Provenance(s"Spark config ${conf.key}", _)) + } + def fromEnv(name: String): Option[Provenance] = { + sys.env.get(name).map(Provenance(s"Environment variable $name", _)) + } + } + + override def run(args: Array[String], maybeConda: Option[CondaEnvironment]): Unit = { val pythonFile = args(0) val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() - val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) - .orElse(sparkConf.get(PYSPARK_PYTHON)) - .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) - .orElse(sys.env.get("PYSPARK_PYTHON")) - .getOrElse("python") + val presetPythonExec = Provenance.fromConf(sparkConf, PYSPARK_DRIVER_PYTHON) + .orElse(Provenance.fromConf(sparkConf, PYSPARK_PYTHON)) + .orElse(Provenance.fromEnv("PYSPARK_DRIVER_PYTHON")) + .orElse(Provenance.fromEnv("PYSPARK_PYTHON")) + + val pythonExec: String = maybeConda.map { conda => + presetPythonExec.foreach { exec => + sys.error( + s"It's forbidden to set the PYSPARK python path when using conda, but found: $exec") + } + conda.condaEnvDir + "/bin/python" + }.orElse(presetPythonExec.map(_.value)) + .getOrElse("python") + + logInfo(s"Python binary that will be called: $pythonExec") // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) @@ -78,6 +104,8 @@ object PythonRunner { // Launch Python process val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() + // If there is a CondaEnvironment set up, initialise our process' env from that + maybeConda.foreach(_.initializeJavaEnvironment(env)) env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 223c921810378..0928f569e4906 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -169,6 +169,30 @@ package object config { .stringConf .createOptional + private[spark] val CONDA_BOOTSTRAP_PACKAGES = ConfigBuilder("spark.conda.bootstrapPackages") + .doc("The packages that will be added to the conda environment. " + + "Only relevant when main class is CondaRunner.") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val CONDA_CHANNEL_URLS = ConfigBuilder("spark.conda.channelUrls") + .doc("The URLs the Conda channels to use when resolving the conda packages. " + + "Only relevant when main class is CondaRunner.") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val CONDA_BINARY_PATH = ConfigBuilder("spark.conda.binaryPath") + .doc("The location of the conda binary. 
Only relevant when main class is CondaRunner.") + .stringConf + .createOptional + + private[spark] val CONDA_VERBOSITY = ConfigBuilder("spark.conda.verbosity") + .doc("How many times to apply -v to conda. A number between 0 and 3, with 0 being default.") + .intConf + .createWithDefault(0) + // To limit memory usage, we only track information for a fixed number of tasks private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks") .intConf diff --git a/dev/run-tests.py b/dev/run-tests.py index 8ce30f54faace..b18e1d511260e 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -270,7 +270,7 @@ def exec_maven(mvn_args=()): kill_zinc_on_port(zinc_port) -def exec_sbt(sbt_args=(), exit_on_failure=True): +def exec_sbt(sbt_args=()): """Will call SBT in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" @@ -295,8 +295,7 @@ def exec_sbt(sbt_args=(), exit_on_failure=True): retcode = sbt_proc.wait() if retcode != 0: - if exit_on_failure: - exit_from_command_with_retcode(sbt_cmd, retcode) + exit_from_command_with_retcode(sbt_cmd, retcode) return sbt_cmd, retcode @@ -406,7 +405,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): print("[info] Running Spark tests using SBT with these arguments: ", " ".join(profiles_and_goals)) - exec_sbt(profiles_and_goals) + sbt_cmd, retcode = exec_sbt(profiles_and_goals) def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56b8c0b95e8a4..721e901be97f1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -49,6 +49,9 @@ object MimaExcludes { // [SPARK-18537] Add a REST api to spark streaming ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"), + // CondaRunner is meant to own the main() method then delegate to another method + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.deploy.CondaRunner.main"), + // [SPARK-19148][SQL] do not expose the external table concept in Catalog ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"), diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2961cda553d6a..c28a7709f56e2 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -188,7 +188,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) self._jsc.sc().register(self._javaAccumulator) - self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') + self.pythonExec = self._jvm.scala.Option.apply(os.environ.get("PYSPARK_PYTHON")) self.pythonVer = "%d.%d" % sys.version_info[:2] if sys.version_info < (2, 7): @@ -837,6 +837,20 @@ def addPyFile(self, path): import importlib importlib.invalidate_caches() + def addCondaPackages(self, *packages): + """ + Add a conda `package match specification + `_ for all tasks to be executed on + this SparkContext in the future. + """ + self._jsc.addCondaPackages(packages) + + def addCondaChannel(self, url): + self._jsc.sc().addCondaChannel(url) + + def _build_conda_instructions(self): + return self._jsc.sc().buildCondaInstructions() + def setCheckpointDir(self, dirName): """ Set the directory under which RDDs are going to be checkpointed. 
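
// Example wiring (not part of the patch) for the new spark.conda.* keys defined above; the conda
// binary location is a hypothetical path, while the channel and bootstrap package match the
// values used by this patch's tests. These settings only take effect when the main class is a
// CondaRunner, such as org.apache.spark.deploy.PythonRunner.
import org.apache.spark.SparkConf

object CondaConfExample {
  val conf = new SparkConf()
    .set("spark.conda.binaryPath", "/opt/miniconda3/bin/conda") // hypothetical install location
    .set("spark.conda.channelUrls", "https://repo.continuum.io/pkgs/free") // comma-separated list
    .set("spark.conda.bootstrapPackages", "python=3.5") // comma-separated list
    .set("spark.conda.verbosity", "1") // 0-3; each level adds one -v flag to conda invocations
}
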
The diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3c783ae541a1f..66d5c783dc367 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -115,6 +115,7 @@ def killChild(): # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") + java_import(gateway.jvm, "org.apache.spark.api.conda.*") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.ml.python.*") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a5e6e2b054963..4a7b66268221a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2370,8 +2370,9 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None): assert serializer, "serializer should not be empty" command = (func, profiler, deserializer, serializer) pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, + sc._build_conda_instructions(), sc.pythonExec, sc.pythonVer, + broadcast_vars, sc._javaAccumulator) class PipelinedRDD(RDD): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 376b86ea69bd4..ef1e0cc29dd6a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1859,8 +1859,9 @@ def sort_array(col, asc=True): def _wrap_function(sc, func, returnType): command = (func, returnType) pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, + sc._build_conda_instructions(), sc.pythonExec, sc.pythonVer, + broadcast_vars, sc._javaAccumulator) class UserDefinedFunction(object): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c6c87a9ea5555..b7704d0d5395d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2057,6 +2057,35 @@ def test_user_configuration(self): out, err = proc.communicate() self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) + def test_conda(self): + """Submit and test a single script file via conda""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |sc = SparkContext() + |sc.addCondaPackages('numpy=1.11.1') + | + |# Ensure numpy is accessible on the driver + |import numpy + |arr = [1, 2, 3] + |def mul2(x): + | # Also ensure numpy accessible from executor + | assert numpy.version.version == "1.11.1" + | return x * 2 + |print(sc.parallelize(arr).map(mul2).collect()) + """) + props = self.createTempFile("properties", """ + |spark.conda.binaryPath {} + |spark.conda.channelUrls https://repo.continuum.io/pkgs/free + |spark.conda.bootstrapPackages python=3.5 + """.format(os.environ["CONDA_BIN"])) + proc = subprocess.Popen([self.sparkSubmit, + "--properties-file", props, + script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + class ContextTests(unittest.TestCase): diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 9c3b18e4ec5f3..7cc3075eb766c 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -137,7 +137,8 @@ abstract class BaseYarnClusterSuite extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { + extraEnv: Map[String, String] = Map(), + timeoutDuration: FiniteDuration = 2.minutes): SparkAppHandle.State = { val deployMode = if (clientMode) "client" else "cluster" val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv @@ -167,7 +168,7 @@ abstract class BaseYarnClusterSuite val handle = launcher.startApplication() try { - eventually(timeout(2 minutes), interval(1 second)) { + eventually(timeout(timeoutDuration), interval(1 second)) { assert(handle.getState().isFinal()) } } finally { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 99fb58a28934a..40dc2ec26a5fd 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -74,6 +74,36 @@ class YarnClusterSuite extends BaseYarnClusterSuite { | sc.stop() """.stripMargin + private val TEST_CONDA_PYFILE = """ + |import mod1, mod2 + |import sys + |from operator import add + | + |from pyspark import SparkConf , SparkContext + |if __name__ == "__main__": + | if len(sys.argv) != 2: + | print >> sys.stderr, "Usage: test.py [result file]" + | exit(-1) + | sc = SparkContext(conf=SparkConf()) + | + | sc.addCondaPackages('numpy=1.11.1') + | import numpy + | + | status = open(sys.argv[1],'w') + | + | numpy_multiply = lambda x: numpy.multiply(x, mod1.func() * mod2.func()) + | + | rdd = sc.parallelize(range(10)).map(numpy_multiply) + | cnt = rdd.count() + | if cnt == 10: + | result = "success" + | else: + | result = "failure" + | status.write(result) + | status.close() + | sc.stop() + """.stripMargin + private val TEST_PYMODULE = """ |def func(): | return 42 @@ -139,6 +169,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testPySpark(false) } + test("run Python application within Conda in yarn-client mode") { + testCondaPySpark(true) + } + + test("run Python application within Conda in yarn-cluster mode") { + testCondaPySpark(false) + } + test("run Python application in yarn-cluster mode using " + " spark.yarn.appMasterEnv to override local envvar") { testPySpark( @@ -273,6 +311,55 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, result) } + private def testCondaPySpark( + clientMode: Boolean, + extraEnv: Map[String, String] = Map()): Unit = { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_CONDA_PYFILE, primaryPyFile, StandardCharsets.UTF_8) + + // When running tests, let's not assume the user has built the assembly module, which also + // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the + // needed locations. 
+ val sparkHome = sys.props("spark.test.home") + val pythonPath = Seq( + s"$sparkHome/python/lib/py4j-0.10.4-src.zip", + s"$sparkHome/python") + val extraEnvVars = Map( + "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), + "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv + + val extraConf: Map[String, String] = Map( + "spark.conda.binaryPath" -> sys.env("CONDA_BIN"), + "spark.conda.channelUrls" -> "https://repo.continuum.io/pkgs/free", + "spark.conda.bootstrapPackages" -> "python=3.5" + ) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files" -> pyFiles), + appArgs = Seq(result.getAbsolutePath()), + extraEnv = extraEnvVars, + extraConf = extraConf, + timeoutDuration = 4.minutes) // give it a bit longer + checkResult(finalState, result) + } + private def testUseClassPathFirst(clientMode: Boolean): Unit = { // Create a jar file that contains a different version of "test.resource". val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 2a3d1cf0b298a..0d67796efe72b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -99,7 +99,8 @@ class DummyUDF extends PythonFunction( command = Array[Byte](), envVars = Map("" -> "").asJava, pythonIncludes = ArrayBuffer("").asJava, - pythonExec = "", + condaSetupInstructions = None, + pythonExec = None, pythonVer = "", broadcastVars = null, accumulator = null)
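
// Sketch (not part of the patch), for contrast with DummyUDF above: a conda-enabled
// PythonFunction would carry Some(CondaSetupInstructions) and leave pythonExec as None, since
// PythonWorkerFactory now derives the interpreter from the environment it creates. The class
// name and values are illustrative, reusing the placeholders from this patch's Python tests.
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.api.python.PythonFunction

class CondaDummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  condaSetupInstructions = Some(CondaSetupInstructions(
    packages = Seq("numpy=1.11.1"),
    channels = Seq("https://repo.continuum.io/pkgs/free"))),
  pythonExec = None, // must stay None when conda instructions are supplied
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)
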