5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,6 +18,7 @@
package org.apache.spark

import java.io.File
import java.net.Socket

import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
}

private[spark]
def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers(key).stop()
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
}
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
if (!context.completed) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap)
env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {

/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
* This function is outdated, PySpark does not use it anymore
*/
@deprecated
@deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
@@ -17,9 +17,11 @@

package org.apache.spark.api.python

import java.lang.Runtime
import java.io.{DataInputStream, InputStream, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}

import scala.collection.mutable
import scala.collection.JavaConversions._

import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()

var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()

val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
@@ -65,10 +70,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Attempt to connect, restart and retry once if it fails
try {
val socket = new Socket(daemonHost, daemonPort)
val launchStatus = new DataInputStream(socket.getInputStream).readInt()
if (launchStatus != 0) {
val pid = new DataInputStream(socket.getInputStream).readInt()
if (pid < 0) {
throw new IllegalStateException("Python daemon failed to launch worker")
}
daemonWorkers.put(socket, pid)
socket
} catch {
case exc: SocketException =>
@@ -107,7 +113,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
return serverSocket.accept()
val socket = serverSocket.accept()
simpleWorkers.put(socket, pb)
Contributor:
It looks like simpleWorkers is declared as a map of Process but here you're storing a ProcessBuilder; IntelliJ displays this as an error, but it still seems to compile. Any idea what's going on here?

Contributor:
As an experiment, I added the line

simpleWorkers(socket).destroy()

and, as expected, this results in a java.lang.ProcessBuilder cannot be cast to java.lang.Process error.

The compiler should have prevented this, so I think we've found a compiler bug.

Contributor:
Yep, this looks like a compiler bug. This file compiles in 2.10.4 and gives the expected error in 2.11.2:

import scala.collection.JavaConversions._
import java.lang.Process
import java.lang.ProcessBuilder

object ImplicitBug {
  def main(args: Seq[String]) {
    val simpleWorkers = new scala.collection.mutable.WeakHashMap[Int, Process]()
    val pb = new ProcessBuilder()
    simpleWorkers.put(1, pb)
  }
}

For the curious, here are the actual implicit conversions that led to this bug (run /scala-2.10.4/bin/scalac -Xprint:typer ImplicitBug.scala to get this output):

[[syntax trees at end of                     typer]] // ImplicitBug.scala
package <empty> {
  import scala.collection.JavaConversions._;
  import java.lang.Process;
  import java.lang.ProcessBuilder;
  object ImplicitBug extends scala.AnyRef {
    def <init>(): ImplicitBug.type = {
      ImplicitBug.super.<init>();
      ()
    };
    def main(args: Seq[String]): Unit = {
      val simpleWorkers: scala.collection.mutable.WeakHashMap[Int,Process] = new scala.collection.mutable.WeakHashMap[Int,Process]();
      val pb: ProcessBuilder = new java.lang.ProcessBuilder();
      {
        // Somehow it became a map with `Object` values:
        scala.collection.JavaConversions.mapAsJavaMap[Int, Object](simpleWorkers).put(1, pb);
        ()
      }
    }
  }
}

return socket
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
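
Following up on the Process vs. ProcessBuilder thread above: a small, self-contained sketch of what storing the started Process rather than the builder could look like. The object name and the command are invented for illustration; without JavaConversions in scope, the mismatched put is rejected by the compiler, and storing pb.start() makes the stored value match the map's declared value type.

import scala.collection.mutable

object TypedPutSketch {
  def main(args: Array[String]): Unit = {
    val workers = new mutable.WeakHashMap[Int, Process]()
    val pb = new ProcessBuilder("sleep", "1")

    // workers.put(1, pb)          // does not compile here: pb is a ProcessBuilder, not a Process
    workers.put(1, pb.start())     // store the started Process, matching the map's value type

    workers.get(1).foreach(_.destroy())   // destroy() now resolves against an actual Process
  }
}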
@@ -189,19 +197,34 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

private def stopDaemon() {
synchronized {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}
if (useDaemon) {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}

daemon = null
daemonPort = 0
daemon = null
daemonPort = 0
} else {
simpleWorkers.mapValues(_.destroy())
}
}
}

def stop() {
stopDaemon()
}

def stopWorker(worker: Socket) {
if (useDaemon) {
daemonWorkers.get(worker).foreach {
Contributor:
The other accesses of daemonWorkers are guarded by synchronized blocks; does this access also need synchronization? It looks like calls to stopWorker() only occur from destroyPythonWorker(), which is synchronized using the SparkEnv object, but that's a different lock. To be on the safe side, we should probably add synchronized here unless there's a good reason not to.

Contributor:
Actually, I think the current synchronization is fine: every call of PythonWorkerFactory's public methods is guarded by SparkEnv's lock.

pid => Runtime.getRuntime.exec("kill " + pid.toString)
}
} else {
simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}
}

private object PythonWorkerFactory {
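On the synchronization question raised in the stopWorker thread above: the reviewers conclude that every public call into the factory is guarded by SparkEnv's lock, but a more defensive variant would take the factory's own monitor as well. Below is a hedged, self-contained sketch of that variant; the class name, the simplified useDaemon check, and the field initializers are invented to mirror the diff, not the actual PythonWorkerFactory.

import java.net.Socket
import scala.collection.mutable

class WorkerRegistrySketch {
  // Simplified stand-ins for the factory's bookkeeping maps in the diff above.
  private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
  private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
  private val useDaemon = !System.getProperty("os.name").startsWith("Windows")

  // synchronized here means callers no longer need to hold any external lock.
  def stopWorker(worker: Socket): Unit = synchronized {
    if (useDaemon) {
      // Ask the OS to terminate the forked worker by pid, as the diff does.
      daemonWorkers.get(worker).foreach { pid =>
        Runtime.getRuntime.exec("kill " + pid.toString)
      }
    } else {
      simpleWorkers.get(worker).foreach(_.destroy())
    }
    worker.close()
  }
}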
13 changes: 9 additions & 4 deletions python/pyspark/daemon.py
@@ -67,7 +67,8 @@ def waitSocketClose(sock):
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
write_int(0, outfile) # Acknowledge that the fork was successful
# Acknowledge that the fork was successful
write_int(os.getpid(), outfile)
outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
Expand Down Expand Up @@ -131,8 +132,8 @@ def handle_sigchld(*args):
sock, addr = listen_sock.accept()
# Launch a worker process
try:
fork_return_code = os.fork()
if fork_return_code == 0:
pid = os.fork()
if pid == 0:
listen_sock.close()
try:
worker(sock)
@@ -141,13 +142,17 @@
os._exit(1)
else:
os._exit(0)
else:
elif pid > 0:
sock.close()
else:
raise OSError("fork failed")
Contributor:
I think os.fork() already signals failure by raising OSError rather than returning a negative value, so this else block is dead code: https://docs.python.org/2/library/os.html#os.fork


except OSError as e:
print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
write_int(-1, outfile) # Signal that the fork failed
outfile.flush()
outfile.close()
sock.close()
finally:
shutdown(1)
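
To make the new handshake concrete end to end: a toy, self-contained sketch in which a plain JVM thread stands in for the Python daemon. The daemon now writes the forked worker's pid (or a negative value when the fork fails) and the JVM side reads it with readInt, which is the shape of the change above; the object name, pid value, and use of an ephemeral port here are invented for illustration, and write_int on the Python side is assumed to emit the big-endian 32-bit int that readInt expects.

import java.io.{DataInputStream, DataOutputStream}
import java.net.{InetAddress, ServerSocket, Socket}

object PidHandshakeSketch {
  def main(args: Array[String]): Unit = {
    val loopback = InetAddress.getByName("127.0.0.1")
    val server = new ServerSocket(0, 1, loopback)   // stands in for the daemon's listening socket

    val daemonSide = new Thread(new Runnable {
      def run(): Unit = {
        val sock = server.accept()
        val out = new DataOutputStream(sock.getOutputStream)
        out.writeInt(12345)   // pretend worker pid; a negative value would signal a failed fork
        out.flush()
      }
    })
    daemonSide.start()

    val socket = new Socket(loopback, server.getLocalPort)
    val pid = new DataInputStream(socket.getInputStream).readInt()
    if (pid < 0) {
      throw new IllegalStateException("Python daemon failed to launch worker")
    }
    println("daemon reported worker pid " + pid)    // PythonWorkerFactory would record socket -> pid here
    socket.close()
    server.close()
  }
}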