-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3030] [PySpark] Reuse Python worker #2259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
8d2f08c
6123d0f
ace2917
583716e
e0131a2
6325fc1
8911f44
ac3206e
7abb224
760ab1f
3133a60
cf1c55e
3939f20
f11f617
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ import java.nio.charset.Charset | |
| import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} | ||
|
|
||
| import scala.collection.JavaConversions._ | ||
| import scala.collection.mutable | ||
| import scala.language.existentials | ||
| import scala.reflect.ClassTag | ||
| import scala.util.{Try, Success, Failure} | ||
|
|
@@ -52,6 +53,7 @@ private[spark] class PythonRDD( | |
| extends RDD[Array[Byte]](parent) { | ||
|
|
||
| val bufferSize = conf.getInt("spark.buffer.size", 65536) | ||
| val reuse_worker = conf.getBoolean("spark.python.reuse.worker", true) | ||
|
|
||
| override def getPartitions = parent.partitions | ||
|
|
||
|
|
@@ -63,20 +65,17 @@ private[spark] class PythonRDD( | |
| val localdir = env.blockManager.diskBlockManager.localDirs.map( | ||
| f => f.getPath()).mkString(",") | ||
| envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread | ||
| if (reuse_worker) { | ||
| envVars += ("SPARK_REUSE_WORKER" -> "1") | ||
| } | ||
| val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) | ||
|
|
||
| // Start a thread to feed the process input from our parent's iterator | ||
| val writerThread = new WriterThread(env, worker, split, context) | ||
|
|
||
| context.addTaskCompletionListener { context => | ||
| writerThread.shutdownOnTaskCompletion() | ||
|
|
||
| // Cleanup the worker socket. This will also cause the Python worker to exit. | ||
| try { | ||
| worker.close() | ||
| } catch { | ||
| case e: Exception => logWarning("Failed to close worker socket", e) | ||
| } | ||
| env.releasePythonWorker(pythonExec, envVars.toMap, worker) | ||
| } | ||
|
|
||
| writerThread.start() | ||
|
|
@@ -195,18 +194,34 @@ private[spark] class PythonRDD( | |
| PythonRDD.writeUTF(include, dataOut) | ||
| } | ||
| // Broadcast variables | ||
| dataOut.writeInt(broadcastVars.length) | ||
| val bids = PythonRDD.getWorkerBroadcasts(worker) | ||
| val nbids = broadcastVars.map(_.id).toSet | ||
| // number of different broadcasts | ||
| val cnt = bids.diff(nbids).size + nbids.diff(bids).size | ||
| dataOut.writeInt(cnt) | ||
| for (bid <- bids) { | ||
| if (!nbids.contains(bid)) { | ||
| // remove the broadcast from worker | ||
| dataOut.writeLong(-bid) | ||
| bids.remove(bid) | ||
| } | ||
| } | ||
| for (broadcast <- broadcastVars) { | ||
| dataOut.writeLong(broadcast.id) | ||
| dataOut.writeInt(broadcast.value.length) | ||
| dataOut.write(broadcast.value) | ||
| if (!bids.contains(broadcast.id)) { | ||
| // send new broadcast | ||
| dataOut.writeLong(broadcast.id) | ||
| dataOut.writeInt(broadcast.value.length) | ||
| dataOut.write(broadcast.value) | ||
| bids.add(broadcast.id) | ||
| } | ||
| } | ||
| dataOut.flush() | ||
| // Serialized command: | ||
| dataOut.writeInt(command.length) | ||
| dataOut.write(command) | ||
| // Data values | ||
| PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) | ||
| dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) | ||
| dataOut.flush() | ||
| } catch { | ||
| case e: Exception if context.isCompleted || context.isInterrupted => | ||
|
|
@@ -216,8 +231,6 @@ private[spark] class PythonRDD( | |
| // We must avoid throwing exceptions here, because the thread uncaught exception handler | ||
| // will kill the whole executor (see org.apache.spark.executor.Executor). | ||
| _exception = e | ||
| } finally { | ||
| Try(worker.shutdownOutput()) // kill Python worker process | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -278,6 +291,12 @@ private object SpecialLengths { | |
| private[spark] object PythonRDD extends Logging { | ||
| val UTF8 = Charset.forName("UTF-8") | ||
|
|
||
| // remember the broadcasts sent to each worker | ||
| private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() | ||
| private def getWorkerBroadcasts(worker: Socket) = { | ||
| workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this method have synchronization? I think getWorkerBroadcasts will be called from multiple threads. |
||
| } | ||
|
|
||
| /** | ||
| * Adapter for calling SparkContext#runJob from Python. | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,7 @@ echo "Running PySpark tests. Output is in python/unit-tests.log." | |
|
|
||
| # Try to test with Python 2.6, since that's the minimum version that we support: | ||
| if [ $(which python2.6) ]; then | ||
| export PYSPARK_PYTHON="python2.6" | ||
| export PYSPARK_PYTHON="pypy" | ||
|
||
| fi | ||
|
|
||
| echo "Testing with Python version:" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might be clearer to name these
`oldBids` and `newBids` instead of `bids` and `nbids`.