diff --git a/src/Runner.Listener/Runner.cs b/src/Runner.Listener/Runner.cs index 51df80b2491..0935afca46d 100644 --- a/src/Runner.Listener/Runner.cs +++ b/src/Runner.Listener/Runner.cs @@ -6,6 +6,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Security.Cryptography; +using System.Security.Claims; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -15,7 +16,9 @@ using GitHub.Runner.Listener.Check; using GitHub.Runner.Listener.Configuration; using GitHub.Runner.Sdk; +using GitHub.Services.OAuth; using GitHub.Services.WebApi; +using GitHub.Services.WebApi.Jwt; using Pipelines = GitHub.DistributedTask.Pipelines; namespace GitHub.Runner.Listener @@ -35,8 +38,11 @@ public sealed class Runner : RunnerService, IRunner private readonly ConcurrentQueue _authMigrationTelemetries = new(); private Task _authMigrationTelemetryTask; private readonly object _authMigrationTelemetryLock = new(); + private Task _authMigrationClaimsCheckTask; + private readonly object _authMigrationClaimsCheckLock = new(); private IRunnerServer _runnerServer; private CancellationTokenSource _authMigrationTelemetryTokenSource = new(); + private CancellationTokenSource _authMigrationClaimsCheckTokenSource = new(); // // Helps avoid excessive calls to Run Service when encountering non-retriable errors from /acquirejob. 
@@ -329,6 +335,7 @@ public async Task ExecuteCommand(CommandSettings command)
             }
             finally
             {
+                _authMigrationClaimsCheckTokenSource?.Cancel();
                 _authMigrationTelemetryTokenSource?.Cancel();
                 HostContext.AuthMigrationChanged -= HandleAuthMigrationChanged;
                 _term.CancelKeyPress -= CtrlCHandler;
@@ -756,6 +763,156 @@ private void HandleAuthMigrationChanged(object sender, AuthMigrationEventArgs e)
                     _authMigrationTelemetryTask = ReportAuthMigrationTelemetryAsync(_authMigrationTelemetryTokenSource.Token);
                 }
             }
+
+            // only start the claims check task once auth migration is changed (enabled or disabled)
+            lock (_authMigrationClaimsCheckLock)
+            {
+                if (_authMigrationClaimsCheckTask == null)
+                {
+                    _authMigrationClaimsCheckTask = CheckOAuthTokenClaimsAsync(_authMigrationClaimsCheckTokenSource.Token);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Background loop that periodically compares the runner claims of the access token issued by the
+        /// auth v2 authorization url against the token issued by the baseline authorization url, and defers
+        /// auth migration when a claim present in the baseline token is missing or different in the v2 token.
+        /// </summary>
+        /// <param name="token">Stops the loop once cancellation is requested.</param>
+        /// <returns>A task that completes when the loop exits; exceptions are logged, not propagated.</returns>
+        private async Task CheckOAuthTokenClaimsAsync(CancellationToken token)
+        {
+            // Lower-case claim types that must carry the same value in both tokens.
+            string[] expectedClaims =
+            [
+                "owner_id",
+                "runner_id",
+                "runner_group_id",
+                "scale_set_id",
+                "is_ephemeral",
+                "labels"
+            ];
+
+            try
+            {
+                var credMgr = HostContext.GetService<ICredentialManager>();
+                while (!token.IsCancellationRequested)
+                {
+                    try
+                    {
+                        await HostContext.Delay(TimeSpan.FromMinutes(100), token);
+                    }
+                    catch (TaskCanceledException)
+                    {
+                        // Ignore cancellation
+                    }
+
+                    if (token.IsCancellationRequested)
+                    {
+                        break;
+                    }
+
+                    if (!HostContext.AllowAuthMigration)
+                    {
+                        Trace.Info("Skip checking oauth token claims since auth migration is disabled.");
+                        continue;
+                    }
+
+                    var baselineCred = credMgr.LoadCredentials(allowAuthUrlV2: false);
+                    var authV2Cred = credMgr.LoadCredentials(allowAuthUrlV2: true);
+
+                    // A successful type pattern already implies non-null, so no extra null checks are needed.
+                    if (baselineCred.Federated is not VssOAuthCredential baselineVssOAuthCred ||
+                        authV2Cred.Federated is not VssOAuthCredential vssOAuthCredV2)
+                    {
+                        Trace.Info("Skip checking oauth token claims for non-oauth credentials");
+                        continue;
+                    }
+
+                    if (string.Equals(baselineVssOAuthCred.AuthorizationUrl.AbsoluteUri, vssOAuthCredV2.AuthorizationUrl.AbsoluteUri, StringComparison.OrdinalIgnoreCase))
+                    {
+                        Trace.Info("Skip checking oauth token claims for same authorization url");
+                        continue;
+                    }
+
+                    var baselineProvider = baselineVssOAuthCred.GetTokenProvider(baselineVssOAuthCred.AuthorizationUrl);
+                    var v2Provider = vssOAuthCredV2.GetTokenProvider(vssOAuthCredV2.AuthorizationUrl);
+                    try
+                    {
+                        // Bound both token requests to 30 seconds while still honoring the caller's cancellation.
+                        using (var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(30)))
+                        using (var requestTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, timeoutTokenSource.Token))
+                        {
+                            var baselineToken = await baselineProvider.GetTokenAsync(null, requestTokenSource.Token);
+                            var v2Token = await v2Provider.GetTokenAsync(null, requestTokenSource.Token);
+                            if (baselineToken is VssOAuthAccessToken baselineAccessToken &&
+                                v2Token is VssOAuthAccessToken v2AccessToken &&
+                                !string.IsNullOrEmpty(baselineAccessToken.Value) &&
+                                !string.IsNullOrEmpty(v2AccessToken.Value))
+                            {
+                                // Materialize the claims once; they are enumerated several times below.
+                                var baselineClaims = JsonWebToken.Create(baselineAccessToken.Value).ExtractClaims().ToList();
+                                var v2Claims = JsonWebToken.Create(v2AccessToken.Value).ExtractClaims().ToList();
+
+                                // Log extracted claims for debugging
+                                var baselineSummary = string.Join(", ", baselineClaims
+                                    .Where(c => expectedClaims.Contains(c.Type, StringComparer.OrdinalIgnoreCase))
+                                    .Select(c => $"{c.Type}:{c.Value}"));
+                                var v2Summary = string.Join(", ", v2Claims
+                                    .Where(c => expectedClaims.Contains(c.Type, StringComparer.OrdinalIgnoreCase))
+                                    .Select(c => $"{c.Type}:{c.Value}"));
+                                Trace.Verbose($"Baseline token expected claims: {baselineSummary}");
+                                Trace.Verbose($"V2 token expected claims: {v2Summary}");
+
+                                var claimsMatch = true;
+                                foreach (var claim in expectedClaims)
+                                {
+                                    // if baseline has the claim, v2 should have it too with exactly same value.
+                                    var baselineClaim = baselineClaims.FirstOrDefault(c => string.Equals(c.Type, claim, StringComparison.OrdinalIgnoreCase));
+                                    if (string.IsNullOrEmpty(baselineClaim?.Value))
+                                    {
+                                        continue;
+                                    }
+
+                                    var v2Claim = v2Claims.FirstOrDefault(c => string.Equals(c.Type, claim, StringComparison.OrdinalIgnoreCase));
+                                    if (v2Claim?.Value != baselineClaim.Value)
+                                    {
+                                        Trace.Info($"Token Claim mismatch between two issuers. Expected: {baselineClaim.Type}:{baselineClaim.Value}. Actual: {v2Claim?.Type ?? "Empty"}:{v2Claim?.Value ?? "Empty"}");
+                                        HostContext.DeferAuthMigration(TimeSpan.FromMinutes(60), $"Expected claim {baselineClaim.Type}:{baselineClaim.Value} does not match {v2Claim?.Type ?? "Empty"}:{v2Claim?.Value ?? "Empty"}");
+                                        claimsMatch = false;
+                                        break;
+                                    }
+                                }
+
+                                // Only report success when no mismatch was found above.
+                                if (claimsMatch)
+                                {
+                                    Trace.Info("OAuth token claims check passed.");
+                                }
+                            }
+                            else
+                            {
+                                Trace.Info("Skip checking oauth token claims since an access token is unavailable.");
+                            }
+                        }
+                    }
+                    catch (OperationCanceledException) when (token.IsCancellationRequested)
+                    {
+                        // The caller cancelled the check; this is not a failure. The loop condition ends the task.
+                    }
+                    catch (Exception ex)
+                    {
+                        Trace.Error("Failed to fetch and check OAuth token claims.");
+                        Trace.Error(ex);
+                    }
+                }
+            }
+            catch (Exception ex)
+            {
+                Trace.Error("Failed to check OAuth token claims in background.");
+                Trace.Error(ex);
+            }
         }
 
         private async Task ReportAuthMigrationTelemetryAsync(CancellationToken token)