Skip to content

Commit 5af8780

Browse files
parthchandraGitHub Enterprise
authored andcommitted
rdar://90787805 Make spark driver endpoints extensible to be used by awsappleconnect-spark-utils (apache#1453)
SPARK-38954: Implement sharing of cloud credentials among driver and executors
1 parent 722b03b commit 5af8780

19 files changed

Lines changed: 841 additions & 106 deletions

core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.deploy
1919

20-
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException}
20+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException, ObjectInputStream, ObjectOutputStream}
2121
import java.security.PrivilegedExceptionAction
2222
import java.text.DateFormat
2323
import java.util.{Arrays, Date, Locale}
@@ -37,6 +37,7 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier}
3737
import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier
3838

3939
import org.apache.spark.{SparkConf, SparkException}
40+
import org.apache.spark.deploy.security.cloud.{CloudCredentials, CloudCredentialsManager}
4041
import org.apache.spark.internal.Logging
4142
import org.apache.spark.internal.config.BUFFER_SIZE
4243
import org.apache.spark.util.Utils
@@ -159,6 +160,17 @@ private[spark] class SparkHadoopUtil extends Logging {
159160
addCurrentUserCredentials(creds)
160161
}
161162

163+
/**
164+
* Add or overwrite current cloud credentials
165+
*/
166+
private[spark] def updateCloudCredentials(credentials: Array[Byte], sparkConf: SparkConf): Unit =
167+
{
168+
val creds = deserializeCloudCredentials(credentials)
169+
logInfo(s"Updating cloud credentials for ${creds.serviceName}.")
170+
sparkConf.set(CloudCredentialsManager.cloudCredentialsConfig.format(creds.serviceName),
171+
creds.credentials)
172+
}
173+
162174
/**
163175
* Returns a function that can be called to find Hadoop FileSystem bytes read. If
164176
* getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
@@ -386,6 +398,21 @@ private[spark] class SparkHadoopUtil extends Logging {
386398
creds
387399
}
388400

401+
def serializeCloudCredentials(creds: CloudCredentials): Array[Byte] = {
402+
val byteStream = new ByteArrayOutputStream
403+
val objStream = new ObjectOutputStream(byteStream)
404+
objStream.writeObject(creds)
405+
objStream.close()
406+
byteStream.toByteArray
407+
}
408+
409+
def deserializeCloudCredentials(credentialsBytes: Array[Byte]): CloudCredentials = {
410+
val objStream = new ObjectInputStream(new ByteArrayInputStream(credentialsBytes))
411+
val creds = objStream.readObject().asInstanceOf[CloudCredentials]
412+
objStream.close()
413+
creds
414+
}
415+
389416
def isProxyUser(ugi: UserGroupInformation): Boolean = {
390417
ugi.getAuthenticationMethod() == UserGroupInformation.AuthenticationMethod.PROXY
391418
}

core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala

Lines changed: 29 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,19 @@ import java.io.File
2121
import java.net.URI
2222
import java.security.PrivilegedExceptionAction
2323
import java.util.ServiceLoader
24-
import java.util.concurrent.{ScheduledExecutorService, TimeUnit}
25-
26-
import scala.collection.mutable
24+
import java.util.concurrent.TimeUnit
2725

2826
import org.apache.hadoop.conf.Configuration
2927
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
3028

3129
import org.apache.spark.SparkConf
3230
import org.apache.spark.deploy.SparkHadoopUtil
33-
import org.apache.spark.internal.Logging
3431
import org.apache.spark.internal.config._
3532
import org.apache.spark.rpc.RpcEndpointRef
3633
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens
3734
import org.apache.spark.security.HadoopDelegationTokenProvider
3835
import org.apache.spark.ui.UIUtils
39-
import org.apache.spark.util.{ThreadUtils, Utils}
36+
import org.apache.spark.util.Utils
4037

4138
/**
4239
* Manager for delegation tokens in a Spark application.
@@ -61,9 +58,11 @@ import org.apache.spark.util.{ThreadUtils, Utils}
6158
* generated.
6259
*/
6360
private[spark] class HadoopDelegationTokenManager(
64-
protected val sparkConf: SparkConf,
65-
protected val hadoopConf: Configuration,
66-
protected val schedulerRef: RpcEndpointRef) extends Logging {
61+
override protected val sparkConf: SparkConf,
62+
override protected val hadoopConf: Configuration,
63+
override protected val schedulerRef: RpcEndpointRef)
64+
extends ServiceCredentialsManager[HadoopDelegationTokenProvider](
65+
sparkConf, hadoopConf, schedulerRef) {
6766

6867
private val principal = sparkConf.get(PRINCIPAL).orNull
6968

@@ -74,11 +73,15 @@ private[spark] class HadoopDelegationTokenManager(
7473
require((principal == null) == (keytab == null),
7574
"Both principal and keytab must be defined, or neither.")
7675

77-
private val delegationTokenProviders = loadProviders()
78-
logDebug("Using the following builtin delegation token providers: " +
79-
s"${delegationTokenProviders.keys.mkString(", ")}.")
76+
val delegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = credentialsProviders
77+
.asInstanceOf[Map[String, HadoopDelegationTokenProvider]]
78+
79+
def credentialsType: String = "Delegation token"
8080

81-
private var renewalExecutor: ScheduledExecutorService = _
81+
def credentialsConfig: ServiceCredentialsConfig = HadoopDelegationTokenManager
82+
83+
def getProviderLoader: ServiceLoader[HadoopDelegationTokenProvider] =
84+
ServiceLoader.load(classOf[HadoopDelegationTokenProvider], Utils.getContextOrSparkClassLoader)
8285

8386
/** @return Whether delegation token renewal is enabled. */
8487
def renewalEnabled: Boolean = sparkConf.get(KERBEROS_RENEWAL_CREDENTIALS) match {
@@ -97,14 +100,15 @@ private[spark] class HadoopDelegationTokenManager(
97100
*
98101
* @return New set of delegation tokens created for the configured principal.
99102
*/
100-
def start(): Array[Byte] = {
101-
require(renewalEnabled, "Token renewal must be enabled to start the renewer.")
102-
require(schedulerRef != null, "Token renewal requires a scheduler endpoint.")
103-
renewalExecutor =
104-
ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Renewal Thread")
103+
override def start(): Array[Byte] = {
104+
super.start()
105+
}
105106

106-
val ugi = UserGroupInformation.getCurrentUser()
107-
if (ugi.isFromKeytab()) {
107+
def updateCredentialsGrantingTicket(): Unit = {
108+
require(renewalExecutor != null,
109+
"Renewal executor must be initialized before updating TGT.")
110+
val ugi = UserGroupInformation.getCurrentUser
111+
if (ugi.isFromKeytab) {
108112
// In Hadoop 2.x, renewal of the keytab-based login seems to be automatic, but in Hadoop 3.x,
109113
// it is configurable (see hadoop.kerberos.keytab.login.autorenewal.enabled, added in
110114
// HADOOP-9567). This task will make sure that the user stays logged in regardless of that
@@ -119,14 +123,10 @@ private[spark] class HadoopDelegationTokenManager(
119123
renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod,
120124
TimeUnit.SECONDS)
121125
}
122-
123-
updateTokensTask()
124126
}
125127

126-
def stop(): Unit = {
127-
if (renewalExecutor != null) {
128-
renewalExecutor.shutdownNow()
129-
}
128+
def updateCredentialsTask(): Array[Byte] = {
129+
updateTokensTask()
130130
}
131131

132132
/**
@@ -171,23 +171,6 @@ private[spark] class HadoopDelegationTokenManager(
171171
(creds, nextRenewal)
172172
}
173173

174-
// Visible for testing.
175-
def isProviderLoaded(serviceName: String): Boolean = {
176-
delegationTokenProviders.contains(serviceName)
177-
}
178-
179-
private def scheduleRenewal(delay: Long): Unit = {
180-
val _delay = math.max(0, delay)
181-
logInfo(s"Scheduling renewal in ${UIUtils.formatDuration(_delay)}.")
182-
183-
val renewalTask = new Runnable() {
184-
override def run(): Unit = {
185-
updateTokensTask()
186-
}
187-
}
188-
renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS)
189-
}
190-
191174
/**
192175
* Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself
193176
* to fetch the next set of tokens when needed.
@@ -224,14 +207,7 @@ private[spark] class HadoopDelegationTokenManager(
224207
ugi.doAs(new PrivilegedExceptionAction[Credentials]() {
225208
override def run(): Credentials = {
226209
val (creds, nextRenewal) = obtainDelegationTokens()
227-
228-
// Calculate the time when new credentials should be created, based on the configured
229-
// ratio.
230-
val now = System.currentTimeMillis
231-
val ratio = sparkConf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO)
232-
val delay = (ratio * (nextRenewal - now)).toLong
233-
logInfo(s"Calculated delay on renewal is $delay, based on next renewal $nextRenewal " +
234-
s"and the ratio $ratio, and current time $now")
210+
val delay = calculateNextRenewalInterval(nextRenewal)
235211
scheduleRenewal(delay)
236212
creds
237213
}
@@ -255,57 +231,12 @@ private[spark] class HadoopDelegationTokenManager(
255231
UserGroupInformation.getCurrentUser()
256232
}
257233
}
258-
259-
private def loadProviders(): Map[String, HadoopDelegationTokenProvider] = {
260-
val loader = ServiceLoader.load(classOf[HadoopDelegationTokenProvider],
261-
Utils.getContextOrSparkClassLoader)
262-
val providers = mutable.ArrayBuffer[HadoopDelegationTokenProvider]()
263-
264-
val iterator = loader.iterator
265-
while (iterator.hasNext) {
266-
try {
267-
providers += iterator.next
268-
} catch {
269-
case t: Throwable =>
270-
logDebug(s"Failed to load built in provider.", t)
271-
}
272-
}
273-
274-
// Filter out providers for which spark.security.credentials.{service}.enabled is false.
275-
providers
276-
.filter { p => HadoopDelegationTokenManager.isServiceEnabled(sparkConf, p.serviceName) }
277-
.map { p => (p.serviceName, p) }
278-
.toMap
279-
}
280234
}
281235

282-
private[spark] object HadoopDelegationTokenManager extends Logging {
283-
private val providerEnabledConfig = "spark.security.credentials.%s.enabled"
236+
private[spark] object HadoopDelegationTokenManager extends ServiceCredentialsConfig {
237+
override def providerEnabledConfig: String = "spark.security.credentials.%s.enabled"
284238

285-
private val deprecatedProviderEnabledConfigs = List(
239+
override def deprecatedProviderEnabledConfigs: List[String] = List(
286240
"spark.yarn.security.tokens.%s.enabled",
287241
"spark.yarn.security.credentials.%s.enabled")
288-
289-
def isServiceEnabled(sparkConf: SparkConf, serviceName: String): Boolean = {
290-
val key = providerEnabledConfig.format(serviceName)
291-
292-
deprecatedProviderEnabledConfigs.foreach { pattern =>
293-
val deprecatedKey = pattern.format(serviceName)
294-
if (sparkConf.contains(deprecatedKey)) {
295-
logWarning(s"${deprecatedKey} is deprecated. Please use ${key} instead.")
296-
}
297-
}
298-
299-
val isEnabledDeprecated = deprecatedProviderEnabledConfigs.forall { pattern =>
300-
sparkConf
301-
.getOption(pattern.format(serviceName))
302-
.map(_.toBoolean)
303-
.getOrElse(true)
304-
}
305-
306-
sparkConf
307-
.getOption(key)
308-
.map(_.toBoolean)
309-
.getOrElse(isEnabledDeprecated)
310-
}
311242
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.deploy.security
19+
20+
import org.apache.spark.SparkConf
21+
import org.apache.spark.internal.Logging
22+
23+
/**
24+
* Helper class to define and access configuration parameters for ServiceCredentialsProviders
25+
*/
26+
private[spark] trait ServiceCredentialsConfig extends Logging{
27+
28+
/**
29+
* Configuration param for enabling the credentials provider
30+
* @return a configuration string
31+
*/
32+
def providerEnabledConfig : String
33+
/**
34+
* Deprecated configuration params for enabling the credentials provider
35+
* @return a list of configuration strings. May be empty
36+
*/
37+
def deprecatedProviderEnabledConfigs : List[String]
38+
39+
def isServiceEnabled(sparkConf: SparkConf, serviceName: String): Boolean = {
40+
val key = providerEnabledConfig.format(serviceName)
41+
42+
deprecatedProviderEnabledConfigs.foreach { pattern =>
43+
val deprecatedKey = pattern.format(serviceName)
44+
if (sparkConf.contains(deprecatedKey)) {
45+
logWarning(s"${deprecatedKey} is deprecated. Please use ${key} instead.")
46+
}
47+
}
48+
49+
val isEnabledDeprecated = deprecatedProviderEnabledConfigs.forall { pattern =>
50+
sparkConf
51+
.getOption(pattern.format(serviceName))
52+
.map(_.toBoolean)
53+
.getOrElse(true)
54+
}
55+
56+
sparkConf
57+
.getOption(key)
58+
.map(_.toBoolean)
59+
.getOrElse(isEnabledDeprecated)
60+
}
61+
62+
}
63+
64+

0 commit comments

Comments
 (0)