@@ -25,14 +25,16 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2525 /// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
2626 /// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
2727 /// </summary>
28- private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
29- = new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
3028 private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
29+ private static readonly ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap = new ( ) ;
30+ private static readonly ConcurrentDictionary < TokenCredentialKey , TokenCredentialData > s_tokenCredentialMap = new ( ) ;
31+ private static SemaphoreSlim s_pcaMapModifierSemaphore = new ( 1 , 1 ) ;
32+ private static SemaphoreSlim s_tokenCredentialMapModifierSemaphore = new ( 1 , 1 ) ;
3133 private static readonly int s_accountPwCacheTtlInHours = 2 ;
3234 private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
3335 private static readonly string s_defaultScopeSuffix = "/.default" ;
3436 private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
35- private readonly SqlClientLogger _logger = new SqlClientLogger ( ) ;
37+ private readonly SqlClientLogger _logger = new ( ) ;
3638 private Func < DeviceCodeResult , Task > _deviceCodeFlowCallback ;
3739 private ICustomWebUi _customWebUI = null ;
3840 private readonly string _applicationClientId = ActiveDirectoryAuthentication . AdoClientId ;
@@ -66,6 +68,11 @@ public static void ClearUserTokenCache()
6668 {
6769 s_pcaMap . Clear ( ) ;
6870 }
71+
72+ if ( ! s_tokenCredentialMap . IsEmpty )
73+ {
74+ s_tokenCredentialMap . Clear ( ) ;
75+ }
6976 }
7077
7178 /// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetDeviceCodeFlowCallback/*'/>
@@ -145,50 +152,40 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
145152 * More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
146153 **/
147154
148- int seperatorIndex = parameters . Authority . LastIndexOf ( '/' ) ;
149- string authority = parameters . Authority . Remove ( seperatorIndex + 1 ) ;
150- string audience = parameters . Authority . Substring ( seperatorIndex + 1 ) ;
155+ int separatorIndex = parameters . Authority . LastIndexOf ( '/' ) ;
156+ string authority = parameters . Authority . Remove ( separatorIndex + 1 ) ;
157+ string audience = parameters . Authority . Substring ( separatorIndex + 1 ) ;
151158 string clientId = string . IsNullOrWhiteSpace ( parameters . UserId ) ? null : parameters . UserId ;
152159
153160 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDefault )
154161 {
155- DefaultAzureCredentialOptions defaultAzureCredentialOptions = new ( )
156- {
157- AuthorityHost = new Uri ( authority ) ,
158- SharedTokenCacheTenantId = audience ,
159- VisualStudioCodeTenantId = audience ,
160- VisualStudioTenantId = audience ,
161- ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
162- } ;
163-
164- // Optionally set clientId when available
165- if ( clientId is not null )
166- {
167- defaultAzureCredentialOptions . ManagedIdentityClientId = clientId ;
168- defaultAzureCredentialOptions . SharedTokenCacheUsername = clientId ;
169- }
170- AccessToken accessToken = await new DefaultAzureCredential ( defaultAzureCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
162+ // Cache DefaultAzureCredenial based on scope, authority, audience, and clientId
163+ TokenCredentialKey tokenCredentialKey = new ( typeof ( DefaultAzureCredential ) , authority , scope , audience , clientId ) ;
164+ AccessToken accessToken = await GetTokenAsync ( tokenCredentialKey , string . Empty , tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
171165 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
172166 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
173167 }
174168
175- TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions ( ) { AuthorityHost = new Uri ( authority ) } ;
169+ TokenCredentialOptions tokenCredentialOptions = new ( ) { AuthorityHost = new Uri ( authority ) } ;
176170
177171 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryManagedIdentity || parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryMSI )
178172 {
179- AccessToken accessToken = await new ManagedIdentityCredential ( clientId , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
173+ // Cache ManagedIdentityCredential based on scope, authority, and clientId
174+ TokenCredentialKey tokenCredentialKey = new ( typeof ( ManagedIdentityCredential ) , authority , scope , string . Empty , clientId ) ;
175+ AccessToken accessToken = await GetTokenAsync ( tokenCredentialKey , string . Empty , tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
180176 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
181177 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
182178 }
183179
184180 AuthenticationResult result = null ;
185181 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
186182 {
187- AccessToken accessToken = await new ClientSecretCredential ( audience , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
183+ // Cache ClientSecretCredential based on scope, authority, audience, and clientId
184+ TokenCredentialKey tokenCredentialKey = new ( typeof ( ClientSecretCredential ) , authority , scope , audience , clientId ) ;
185+ AccessToken accessToken = await GetTokenAsync ( tokenCredentialKey , parameters . Password , tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
188186 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}" , accessToken . ExpiresOn ) ;
189187 return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
190188 }
191-
192189 /*
193190 * Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows
194191 * (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend
@@ -204,7 +201,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
204201 redirectUri = "http://localhost" ;
205202 }
206203#endif
207- PublicClientAppKey pcaKey = new PublicClientAppKey ( parameters . Authority , redirectUri , _applicationClientId
204+ PublicClientAppKey pcaKey = new ( parameters . Authority , redirectUri , _applicationClientId
208205#if NETFRAMEWORK
209206 , _iWin32WindowFunc
210207#endif
@@ -213,7 +210,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
213210#endif
214211 ) ;
215212
216- IPublicClientApplication app = GetPublicClientAppInstance ( pcaKey ) ;
213+ IPublicClientApplication app = await GetPublicClientAppInstanceAsync ( pcaKey , cts . Token ) . ConfigureAwait ( false ) ;
217214
218215 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
219216 {
@@ -248,7 +245,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
248245 if ( null != previousPw &&
249246 previousPw is byte [ ] previousPwBytes &&
250247 // Only get the cached token if the current password hash matches the previously used password hash
251- currPwHash . SequenceEqual ( previousPwBytes ) )
248+ AreEqual ( currPwHash , previousPwBytes ) )
252249 {
253250 result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
254251 }
@@ -353,7 +350,7 @@ private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlo
353350 {
354351 if ( authenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive )
355352 {
356- CancellationTokenSource ctsInteractive = new CancellationTokenSource ( ) ;
353+ CancellationTokenSource ctsInteractive = new ( ) ;
357354#if NETCOREAPP
358355 /*
359356 * On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
@@ -447,16 +444,69 @@ public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirec
447444 => _acquireAuthorizationCodeAsyncCallback . Invoke ( authorizationUri , redirectUri , cancellationToken ) ;
448445 }
449446
450- private IPublicClientApplication GetPublicClientAppInstance ( PublicClientAppKey publicClientAppKey )
447+ private async Task < IPublicClientApplication > GetPublicClientAppInstanceAsync ( PublicClientAppKey publicClientAppKey , CancellationToken cancellationToken )
451448 {
452449 if ( ! s_pcaMap . TryGetValue ( publicClientAppKey , out IPublicClientApplication clientApplicationInstance ) )
453450 {
454- clientApplicationInstance = CreateClientAppInstance ( publicClientAppKey ) ;
455- s_pcaMap . TryAdd ( publicClientAppKey , clientApplicationInstance ) ;
451+ await s_pcaMapModifierSemaphore . WaitAsync ( cancellationToken ) ;
452+ try
453+ {
454+ // Double-check in case another thread added it while we waited for the semaphore
455+ if ( ! s_pcaMap . TryGetValue ( publicClientAppKey , out clientApplicationInstance ) )
456+ {
457+ clientApplicationInstance = CreateClientAppInstance ( publicClientAppKey ) ;
458+ s_pcaMap . TryAdd ( publicClientAppKey , clientApplicationInstance ) ;
459+ }
460+ }
461+ finally
462+ {
463+ s_pcaMapModifierSemaphore . Release ( ) ;
464+ }
456465 }
466+
457467 return clientApplicationInstance ;
458468 }
459469
470+ private static async Task < AccessToken > GetTokenAsync ( TokenCredentialKey tokenCredentialKey , string secret ,
471+ TokenRequestContext tokenRequestContext , CancellationToken cancellationToken )
472+ {
473+ if ( ! s_tokenCredentialMap . TryGetValue ( tokenCredentialKey , out TokenCredentialData tokenCredentialInstance ) )
474+ {
475+ await s_tokenCredentialMapModifierSemaphore . WaitAsync ( cancellationToken ) ;
476+ try
477+ {
478+ // Double-check in case another thread added it while we waited for the semaphore
479+ if ( ! s_tokenCredentialMap . TryGetValue ( tokenCredentialKey , out tokenCredentialInstance ) )
480+ {
481+ tokenCredentialInstance = CreateTokenCredentialInstance ( tokenCredentialKey , secret ) ;
482+ s_tokenCredentialMap . TryAdd ( tokenCredentialKey , tokenCredentialInstance ) ;
483+ }
484+ }
485+ finally
486+ {
487+ s_tokenCredentialMapModifierSemaphore . Release ( ) ;
488+ }
489+ }
490+
491+ if ( ! AreEqual ( tokenCredentialInstance . _secretHash , GetHash ( secret ) ) )
492+ {
493+ // If the secret hash has changed, we need to remove the old token credential instance and create a new one.
494+ await s_tokenCredentialMapModifierSemaphore . WaitAsync ( cancellationToken ) ;
495+ try
496+ {
497+ s_tokenCredentialMap . TryRemove ( tokenCredentialKey , out _ ) ;
498+ tokenCredentialInstance = CreateTokenCredentialInstance ( tokenCredentialKey , secret ) ;
499+ s_tokenCredentialMap . TryAdd ( tokenCredentialKey , tokenCredentialInstance ) ;
500+ }
501+ finally
502+ {
503+ s_tokenCredentialMapModifierSemaphore . Release ( ) ;
504+ }
505+ }
506+
507+ return await tokenCredentialInstance . _tokenCredential . GetTokenAsync ( tokenRequestContext , cancellationToken ) ;
508+ }
509+
460510 private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
461511 {
462512 return parameters . Authority + "+" + parameters . UserId ;
@@ -470,6 +520,24 @@ private static byte[] GetHash(string input)
470520 return hashedBytes ;
471521 }
472522
523+ private static bool AreEqual ( byte [ ] a1 , byte [ ] a2 )
524+ {
525+ if ( ReferenceEquals ( a1 , a2 ) )
526+ {
527+ return true ;
528+ }
529+ else if ( a1 is null || a2 is null )
530+ {
531+ return false ;
532+ }
533+ else if ( a1 . Length != a2 . Length )
534+ {
535+ return false ;
536+ }
537+
538+ return a1 . AsSpan ( ) . SequenceEqual ( a2 . AsSpan ( ) ) ;
539+ }
540+
473541 private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
474542 {
475543 IPublicClientApplication publicClientApplication ;
@@ -513,6 +581,59 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ
513581 return publicClientApplication ;
514582 }
515583
584+ private static TokenCredentialData CreateTokenCredentialInstance ( TokenCredentialKey tokenCredentialKey , string secret )
585+ {
586+ if ( tokenCredentialKey . _tokenCredentialType == typeof ( DefaultAzureCredential ) )
587+ {
588+ DefaultAzureCredentialOptions defaultAzureCredentialOptions = new ( )
589+ {
590+ AuthorityHost = new Uri ( tokenCredentialKey . _authority ) ,
591+ SharedTokenCacheTenantId = tokenCredentialKey . _audience ,
592+ VisualStudioCodeTenantId = tokenCredentialKey . _audience ,
593+ VisualStudioTenantId = tokenCredentialKey . _audience ,
594+ ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
595+ } ;
596+
597+ // Optionally set clientId when available
598+ if ( tokenCredentialKey . _clientId is not null )
599+ {
600+ defaultAzureCredentialOptions . ManagedIdentityClientId = tokenCredentialKey . _clientId ;
601+ defaultAzureCredentialOptions . SharedTokenCacheUsername = tokenCredentialKey . _clientId ;
602+ defaultAzureCredentialOptions . WorkloadIdentityClientId = tokenCredentialKey . _clientId ;
603+ }
604+
605+ return new TokenCredentialData ( new DefaultAzureCredential ( defaultAzureCredentialOptions ) , GetHash ( secret ) ) ;
606+ }
607+
608+ TokenCredentialOptions tokenCredentialOptions = new ( ) { AuthorityHost = new Uri ( tokenCredentialKey . _authority ) } ;
609+
610+ if ( tokenCredentialKey . _tokenCredentialType == typeof ( ManagedIdentityCredential ) )
611+ {
612+ return new TokenCredentialData ( new ManagedIdentityCredential ( tokenCredentialKey . _clientId , tokenCredentialOptions ) , GetHash ( secret ) ) ;
613+ }
614+ else if ( tokenCredentialKey . _tokenCredentialType == typeof ( ClientSecretCredential ) )
615+ {
616+ return new TokenCredentialData ( new ClientSecretCredential ( tokenCredentialKey . _audience , tokenCredentialKey . _clientId , secret , tokenCredentialOptions ) , GetHash ( secret ) ) ;
617+ }
618+ else if ( tokenCredentialKey . _tokenCredentialType == typeof ( WorkloadIdentityCredential ) )
619+ {
620+ // The WorkloadIdentityCredentialOptions object initialization populates its instance members
621+ // from the environment variables AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE,
622+ // and AZURE_ADDITIONALLY_ALLOWED_TENANTS. AZURE_CLIENT_ID may be overridden by the User Id.
623+ WorkloadIdentityCredentialOptions options = new ( ) { AuthorityHost = new Uri ( tokenCredentialKey . _authority ) } ;
624+
625+ if ( tokenCredentialKey . _clientId is not null )
626+ {
627+ options . ClientId = tokenCredentialKey . _clientId ;
628+ }
629+
630+ return new TokenCredentialData ( new WorkloadIdentityCredential ( options ) , GetHash ( secret ) ) ;
631+ }
632+
633+ // This should never be reached, but if it is, throw an exception that will be noticed during development
634+ throw new ArgumentException ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
635+ }
636+
516637 internal class PublicClientAppKey
517638 {
518639 public readonly string _authority ;
@@ -572,5 +693,52 @@ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _app
572693#endif
573694 ) . GetHashCode ( ) ;
574695 }
696+
697+ internal class TokenCredentialData
698+ {
699+ public TokenCredential _tokenCredential ;
700+ public byte [ ] _secretHash ;
701+
702+ public TokenCredentialData ( TokenCredential tokenCredential , byte [ ] secretHash )
703+ {
704+ _tokenCredential = tokenCredential ;
705+ _secretHash = secretHash ;
706+ }
707+ }
708+
709+ internal class TokenCredentialKey
710+ {
711+ public readonly Type _tokenCredentialType ;
712+ public readonly string _authority ;
713+ public readonly string _scope ;
714+ public readonly string _audience ;
715+ public readonly string _clientId ;
716+
717+ public TokenCredentialKey ( Type tokenCredentialType , string authority , string scope , string audience , string clientId )
718+ {
719+ _tokenCredentialType = tokenCredentialType ;
720+ _authority = authority ;
721+ _scope = scope ;
722+ _audience = audience ;
723+ _clientId = clientId ;
724+ }
725+
726+ public override bool Equals ( object obj )
727+ {
728+ if ( obj != null && obj is TokenCredentialKey tcKey )
729+ {
730+ return string . CompareOrdinal ( nameof ( _tokenCredentialType ) , nameof ( tcKey . _tokenCredentialType ) ) == 0
731+ && string . CompareOrdinal ( _authority , tcKey . _authority ) == 0
732+ && string . CompareOrdinal ( _scope , tcKey . _scope ) == 0
733+ && string . CompareOrdinal ( _audience , tcKey . _audience ) == 0
734+ && string . CompareOrdinal ( _clientId , tcKey . _clientId ) == 0
735+ ;
736+ }
737+ return false ;
738+ }
739+
740+ public override int GetHashCode ( ) => Tuple . Create ( _tokenCredentialType , _authority , _scope , _audience , _clientId ) . GetHashCode ( ) ;
741+ }
742+
575743 }
576744}
0 commit comments