8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -105,6 +105,14 @@ class SparkEnv (
      pythonWorkers.get(key).foreach(_.stopWorker(worker))
    }
  }

  private[spark]
  def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
    synchronized {
      val key = (pythonExec, envVars)
      pythonWorkers.get(key).foreach(_.releaseWorker(worker))
    }
  }
}

object SparkEnv extends Logging {
45 changes: 32 additions & 13 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
  extends RDD[Array[Byte]](parent) {

  val bufferSize = conf.getInt("spark.buffer.size", 65536)
  val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)

  override def getPartitions = parent.partitions

@@ -63,20 +65,17 @@
    val localdir = env.blockManager.diskBlockManager.localDirs.map(
      f => f.getPath()).mkString(",")
    envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
    if (reuseWorker) {
      envVars += ("SPARK_REUSE_WORKER" -> "1")
    }
    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)

    // Start a thread to feed the process input from our parent's iterator
    val writerThread = new WriterThread(env, worker, split, context)

    context.addTaskCompletionListener { context =>
      writerThread.shutdownOnTaskCompletion()

      // Cleanup the worker socket. This will also cause the Python worker to exit.
      try {
        worker.close()
      } catch {
        case e: Exception => logWarning("Failed to close worker socket", e)
      }
      env.releasePythonWorker(pythonExec, envVars.toMap, worker)
    }

    writerThread.start()
@@ -195,18 +194,34 @@
          PythonRDD.writeUTF(include, dataOut)
        }
        // Broadcast variables
        dataOut.writeInt(broadcastVars.length)
        val bids = PythonRDD.getWorkerBroadcasts(worker)
Review comment (Contributor): I think it might be clearer to name these oldBids and newBids instead of bids and nbids.

        val nbids = broadcastVars.map(_.id).toSet
        // number of different broadcasts
        val cnt = bids.diff(nbids).size + nbids.diff(bids).size
        dataOut.writeInt(cnt)
        for (bid <- bids) {
          if (!nbids.contains(bid)) {
            // remove the broadcast from worker
            dataOut.writeLong(-bid)
            bids.remove(bid)
          }
        }
        for (broadcast <- broadcastVars) {
          dataOut.writeLong(broadcast.id)
          dataOut.writeInt(broadcast.value.length)
          dataOut.write(broadcast.value)
          if (!bids.contains(broadcast.id)) {
            // send new broadcast
            dataOut.writeLong(broadcast.id)
            dataOut.writeInt(broadcast.value.length)
            dataOut.write(broadcast.value)
            bids.add(broadcast.id)
          }
        }
        dataOut.flush()
        // Serialized command:
        dataOut.writeInt(command.length)
        dataOut.write(command)
        // Data values
        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
        dataOut.flush()
      } catch {
        case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -216,8 +231,6 @@
          // We must avoid throwing exceptions here, because the thread uncaught exception handler
          // will kill the whole executor (see org.apache.spark.executor.Executor).
          _exception = e
      } finally {
        Try(worker.shutdownOutput()) // kill Python worker process
      }
    }
  }
@@ -278,6 +291,12 @@ private object SpecialLengths {
private[spark] object PythonRDD extends Logging {
  val UTF8 = Charset.forName("UTF-8")

  // remember the broadcasts sent to each worker
  private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
  private def getWorkerBroadcasts(worker: Socket) = {
    workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
Review comment (Contributor): Should this method have synchronization? I think getWorkerBroadcasts will be called from multiple threads.

  }

  /**
   * Adapter for calling SparkContext#runJob from Python.
   *
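
To make the broadcast bookkeeping in the WriterThread hunk above easier to follow, here is a small Python sketch of the same diff-based protocol, using the reviewer's suggested oldBids/newBids naming. The write_int/write_long helpers and the in-memory stream are stand-ins for Spark's real serializers, not the actual PySpark API.

```python
import struct
from io import BytesIO


def write_int(value, stream):
    # 4-byte big-endian int, matching java.io.DataOutputStream.writeInt
    stream.write(struct.pack("!i", value))


def write_long(value, stream):
    # 8-byte big-endian long, matching java.io.DataOutputStream.writeLong
    stream.write(struct.pack("!q", value))


def send_broadcast_diff(old_bids, broadcasts, stream):
    """Send only the broadcasts that changed since the last task on this worker.

    old_bids   -- set of broadcast ids the worker already holds (mutated in place)
    broadcasts -- list of (bid, payload_bytes) needed by the current task
    """
    new_bids = {bid for bid, _ in broadcasts}
    # one entry per broadcast that has to be added or removed
    write_int(len(old_bids - new_bids) + len(new_bids - old_bids), stream)
    for bid in sorted(old_bids - new_bids):
        write_long(-bid, stream)         # a negative id tells the worker to drop it
        old_bids.discard(bid)
    for bid, payload in broadcasts:
        if bid not in old_bids:
            write_long(bid, stream)      # a positive id is followed by the payload
            write_int(len(payload), stream)
            stream.write(payload)
            old_bids.add(bid)


# Example: the worker already has broadcast 1; the next task needs 2 and 3.
buf = BytesIO()
send_broadcast_diff({1}, [(2, b"two"), (3, b"three")], buf)
print(len(buf.getvalue()), "bytes written")  # count, one removal, two additions
```
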
@@ -41,6 +41,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
  var daemonPort: Int = 0
  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
  var idleWorkers = new mutable.Queue[Socket]()

  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()

@@ -51,6 +52,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

  def create(): Socket = {
    if (useDaemon) {
      if (idleWorkers.length > 0) {
        return idleWorkers.dequeue()
      }
      createThroughDaemon()
    } else {
      createSimpleWorker()
@@ -235,6 +239,20 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
    }
    worker.close()
  }

  def releaseWorker(worker: Socket) {
    if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) {
      idleWorkers.enqueue(worker)
    } else {
      // Cleanup the worker socket. This will also cause the Python worker to exit.
      try {
        worker.close()
      } catch {
        case e: Exception =>
          logWarning("Failed to close worker socket", e)
      }
    }
  }
}

private object PythonWorkerFactory {
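
As a rough illustration of the create()/releaseWorker() life cycle added above, here is a conceptual Python sketch of an idle-worker pool. The WorkerPool class, its plain TCP connection, and the reuse flag are illustrative stand-ins; the real PythonWorkerFactory hands out sockets for workers forked by the daemon.

```python
import os
import socket
from collections import deque


class WorkerPool:
    """Toy pool that mirrors the reuse logic: hand back idle workers first."""

    def __init__(self, daemon_port, reuse=None):
        self.daemon_port = daemon_port
        # mirrors the SPARK_REUSE_WORKER check in releaseWorker()
        self.reuse = bool(os.environ.get("SPARK_REUSE_WORKER")) if reuse is None else reuse
        self.idle = deque()

    def create(self):
        # reuse an idle worker if one is queued, otherwise create a new one
        if self.idle:
            return self.idle.popleft()
        return socket.create_connection(("127.0.0.1", self.daemon_port))

    def release(self, worker):
        if self.reuse:
            self.idle.append(worker)   # keep the worker warm for the next task
        else:
            try:
                worker.close()         # closing the socket lets the worker exit
            except OSError:
                pass
```
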
10 changes: 10 additions & 0 deletions docs/configuration.md
@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
    used during aggregation goes above this amount, it will spill the data into disks.
  </td>
</tr>
<tr>
  <td><code>spark.python.worker.reuse</code></td>
  <td>true</td>
  <td>
    Whether to reuse Python workers. If enabled, Spark keeps a fixed pool of Python workers
    alive, so it does not need to fork() a Python process for every task. This is especially
    useful when there is a large broadcast variable, since the broadcast then does not have to
    be transferred from the JVM to a Python worker for every task.
  </td>
</tr>
<tr>
  <td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
  <td>(none)</td>
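
As a usage note (not part of the patch), the new option is set like any other Spark property; for example, from PySpark with the documented key:

```python
from pyspark import SparkConf, SparkContext

# Worker reuse is documented as on by default; disable it explicitly if needed.
conf = SparkConf().setAppName("reuse-demo").set("spark.python.worker.reuse", "false")
sc = SparkContext(conf=conf)
```

The same setting can be passed on the command line with `spark-submit --conf spark.python.worker.reuse=false`.
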
44 changes: 20 additions & 24 deletions python/pyspark/daemon.py
@@ -23,6 +23,7 @@
import sys
import traceback
import time
import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -42,43 +43,24 @@ def worker(sock):
"""
Called by a worker process after the fork().
"""
# Redirect stdout to stderr
os.dup2(2, 1)
sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1

signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)

# Blocks until the socket is closed by draining the input stream
# until it raises an exception or returns EOF.
def waitSocketClose(sock):
try:
while True:
# Empty string is returned upon EOF (and only then).
if sock.recv(4096) == '':
return
except:
pass

# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
# Acknowledge that the fork was successful
write_int(os.getpid(), outfile)
outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
exit_code = exc.code
exit_code = compute_real_exit_code(exc.code)
finally:
outfile.flush()
# The Scala side will close the socket upon task completion.
waitSocketClose(sock)
os._exit(compute_real_exit_code(exit_code))
if exit_code:
os._exit(exit_code)


# Cleanup zombie children
@@ -102,6 +84,7 @@ def manager():
    listen_sock.listen(max(1024, SOMAXCONN))
    listen_host, listen_port = listen_sock.getsockname()
    write_int(listen_port, sys.stdout)
    sys.stdout.flush()

    def shutdown(code):
        signal.signal(SIGTERM, SIG_DFL)
@@ -114,8 +97,9 @@ def handle_sigterm(*args):
    signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
    signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP

    reuse = os.environ.get("SPARK_REUSE_WORKER")

    # Initialization complete
    sys.stdout.close()
    try:
        while True:
            try:
@@ -167,7 +151,19 @@ def handle_sigterm(*args):
                    # in child process
                    listen_sock.close()
                    try:
                        worker(sock)
                        # Acknowledge that the fork was successful
                        outfile = sock.makefile("w")
                        write_int(os.getpid(), outfile)
                        outfile.flush()
                        outfile.close()
                        while True:
                            worker(sock)
                            if not reuse:
                                # wait for closing
                                while sock.recv(1024):
                                    pass
                                break
                            gc.collect()
                    except:
                        traceback.print_exc()
                        os._exit(1)
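
The child-process loop added to daemon.py above can be summarized with the following stand-alone sketch; serve_forked_worker and run_task are illustrative names, with run_task standing in for one call to worker(sock).

```python
import gc
import os
import traceback


def serve_forked_worker(sock, reuse, run_task):
    """Handle tasks on one forked worker; loop only when reuse is enabled."""
    try:
        while True:
            run_task(sock)            # process exactly one task over the socket
            if not reuse:
                # old behaviour: drain until the JVM closes the socket, then exit
                while sock.recv(1024):
                    pass
                break
            gc.collect()              # free per-task garbage before waiting again
    except Exception:
        traceback.print_exc()
        os._exit(1)
    os._exit(0)
```
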
4 changes: 4 additions & 0 deletions python/pyspark/serializers.py
@@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream):

    def _read_with_length(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        obj = stream.read(length)
        if obj == "":
            raise EOFError
@@ -431,6 +433,8 @@ class UTF8Deserializer(Serializer):

    def loads(self, stream):
        length = read_int(stream)
        if length == SpecialLengths.END_OF_DATA_SECTION:
            raise EOFError
        return stream.read(length).decode('utf8')

    def load_stream(self, stream):
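
Both serializer changes above turn the END_OF_DATA_SECTION marker into an EOFError. Below is a minimal sketch of that framing; the sentinel value is written out here as an assumption rather than imported from pyspark.serializers.SpecialLengths, and read_framed is an illustrative name.

```python
import struct
from io import BytesIO

END_OF_DATA_SECTION = -1  # assumed sentinel; the real constant lives in SpecialLengths


def read_int(stream):
    data = stream.read(4)
    if not data:
        raise EOFError
    return struct.unpack("!i", data)[0]


def read_framed(stream):
    """Read one length-prefixed record, treating the sentinel as end of stream."""
    length = read_int(stream)
    if length == END_OF_DATA_SECTION:
        raise EOFError        # the JVM wrote the end-of-data marker, not a length
    return stream.read(length)


# Usage: two records followed by the end-of-data marker.
buf = BytesIO(struct.pack("!i", 2) + b"hi"
              + struct.pack("!i", 3) + b"foo"
              + struct.pack("!i", END_OF_DATA_SECTION))
print(read_framed(buf), read_framed(buf))  # b'hi' b'foo'
```
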
7 changes: 5 additions & 2 deletions python/pyspark/worker.py
@@ -69,8 +69,11 @@ def main(infile, outfile):
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            value = ser._read_with_length(infile)
            _broadcastRegistry[bid] = Broadcast(bid, value)
            if bid > 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                _broadcastRegistry.pop(-bid, None)

        command = pickleSer._read_with_length(infile)
        (func, deserializer, serializer) = command
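
On the worker side, the signed-id convention from the hunk above reads as: a positive id carries a new broadcast value, a negative id removes the broadcast with that absolute id. A small sketch with a plain dict standing in for _broadcastRegistry (the helper name is illustrative):

```python
def update_broadcast_registry(entries, registry):
    """Apply (bid, value) updates; value is ignored for removals (bid < 0)."""
    for bid, value in entries:
        if bid > 0:
            registry[bid] = value        # new broadcast pushed by the JVM
        else:
            registry.pop(-bid, None)     # drop broadcast abs(bid) if present


# Example: the worker holds broadcast 1; the next task drops it and adds 2.
registry = {1: "old value"}
update_broadcast_registry([(-1, None), (2, "new value")], registry)
print(registry)  # {2: 'new value'}
```
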
2 changes: 1 addition & 1 deletion python/run-tests
@@ -50,7 +50,7 @@ echo "Running PySpark tests. Output is in python/unit-tests.log."

# Try to test with Python 2.6, since that's the minimum version that we support:
if [ $(which python2.6) ]; then
export PYSPARK_PYTHON="python2.6"
export PYSPARK_PYTHON="pypy"
Review comment (Contributor): Looks like this change got pulled in by accident?

Reply (Contributor Author): yes :)

fi

echo "Testing with Python version:"