Skip to content

Commit 493d6b0

Browse files
Count backward eps in numEps hint. (#23)
Signed-off-by: Peter Rudenko <[email protected]> Signed-off-by: Peter Rudenko <[email protected]>
1 parent fb18bc5 commit 493d6b0

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
106106
}
107107

108108
val numEndpoints = ucxShuffleConf.numWorkers *
109-
ucxShuffleConf.getSparkConf.getInt("spark.executor.instances", 1)
109+
ucxShuffleConf.getSparkConf.getInt("spark.executor.instances", 1) *
110+
ucxShuffleConf.numListenerThreads // Each listener thread creates backward endpoint
110111
logInfo(s"Creating UCX context with an estimated number of endpoints: $numEndpoints")
111112

112113
val params = new UcpParams().requestAmFeature().setMtWorkersShared(true).setEstimatedNumEps(numEndpoints)

0 commit comments

Comments
 (0)