Skip to content

Commit 51be41e

Browse files
TingluoHuangCopilot
authored andcommitted
Make sure the token's claims are match as expected. (actions#3846)
Co-authored-by: Copilot <[email protected]>
1 parent fce3dc7 commit 51be41e

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

src/Runner.Listener/Runner.cs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Reflection;
77
using System.Runtime.CompilerServices;
88
using System.Security.Cryptography;
9+
using System.Security.Claims;
910
using System.Text;
1011
using System.Threading;
1112
using System.Threading.Tasks;
@@ -15,7 +16,9 @@
1516
using GitHub.Runner.Listener.Check;
1617
using GitHub.Runner.Listener.Configuration;
1718
using GitHub.Runner.Sdk;
19+
using GitHub.Services.OAuth;
1820
using GitHub.Services.WebApi;
21+
using GitHub.Services.WebApi.Jwt;
1922
using Pipelines = GitHub.DistributedTask.Pipelines;
2023

2124
namespace GitHub.Runner.Listener
@@ -35,8 +38,11 @@ public sealed class Runner : RunnerService, IRunner
3538
private readonly ConcurrentQueue<string> _authMigrationTelemetries = new();
3639
private Task _authMigrationTelemetryTask;
3740
private readonly object _authMigrationTelemetryLock = new();
41+
private Task _authMigrationClaimsCheckTask;
42+
private readonly object _authMigrationClaimsCheckLock = new();
3843
private IRunnerServer _runnerServer;
3944
private CancellationTokenSource _authMigrationTelemetryTokenSource = new();
45+
private CancellationTokenSource _authMigrationClaimsCheckTokenSource = new();
4046

4147
// <summary>
4248
// Helps avoid excessive calls to Run Service when encountering non-retriable errors from /acquirejob.
@@ -329,6 +335,7 @@ public async Task<int> ExecuteCommand(CommandSettings command)
329335
}
330336
finally
331337
{
338+
_authMigrationClaimsCheckTokenSource?.Cancel();
332339
_authMigrationTelemetryTokenSource?.Cancel();
333340
HostContext.AuthMigrationChanged -= HandleAuthMigrationChanged;
334341
_term.CancelKeyPress -= CtrlCHandler;
@@ -756,6 +763,131 @@ private void HandleAuthMigrationChanged(object sender, AuthMigrationEventArgs e)
756763
_authMigrationTelemetryTask = ReportAuthMigrationTelemetryAsync(_authMigrationTelemetryTokenSource.Token);
757764
}
758765
}
766+
767+
// only start the claims check task once auth migration is changed (enabled or disabled)
768+
lock (_authMigrationClaimsCheckLock)
769+
{
770+
if (_authMigrationClaimsCheckTask == null)
771+
{
772+
_authMigrationClaimsCheckTask = CheckOAuthTokenClaimsAsync(_authMigrationClaimsCheckTokenSource.Token);
773+
}
774+
}
775+
}
776+
777+
private async Task CheckOAuthTokenClaimsAsync(CancellationToken token)
778+
{
779+
string[] expectedClaims =
780+
[
781+
"owner_id",
782+
"runner_id",
783+
"runner_group_id",
784+
"scale_set_id",
785+
"is_ephemeral",
786+
"labels"
787+
];
788+
789+
try
790+
{
791+
var credMgr = HostContext.GetService<ICredentialManager>();
792+
while (!token.IsCancellationRequested)
793+
{
794+
try
795+
{
796+
await HostContext.Delay(TimeSpan.FromMinutes(100), token);
797+
}
798+
catch (TaskCanceledException)
799+
{
800+
// Ignore cancellation
801+
}
802+
803+
if (token.IsCancellationRequested)
804+
{
805+
break;
806+
}
807+
808+
if (!HostContext.AllowAuthMigration)
809+
{
810+
Trace.Info("Skip checking oauth token claims since auth migration is disabled.");
811+
continue;
812+
}
813+
814+
var baselineCred = credMgr.LoadCredentials(allowAuthUrlV2: false);
815+
var authV2Cred = credMgr.LoadCredentials(allowAuthUrlV2: true);
816+
817+
if (!(baselineCred.Federated is VssOAuthCredential baselineVssOAuthCred) ||
818+
!(authV2Cred.Federated is VssOAuthCredential vssOAuthCredV2) ||
819+
baselineVssOAuthCred == null ||
820+
vssOAuthCredV2 == null)
821+
{
822+
Trace.Info("Skip checking oauth token claims for non-oauth credentials");
823+
continue;
824+
}
825+
826+
if (string.Equals(baselineVssOAuthCred.AuthorizationUrl.AbsoluteUri, vssOAuthCredV2.AuthorizationUrl.AbsoluteUri, StringComparison.OrdinalIgnoreCase))
827+
{
828+
Trace.Info("Skip checking oauth token claims for same authorization url");
829+
continue;
830+
}
831+
832+
var baselineProvider = baselineVssOAuthCred.GetTokenProvider(baselineVssOAuthCred.AuthorizationUrl);
833+
var v2Provider = vssOAuthCredV2.GetTokenProvider(vssOAuthCredV2.AuthorizationUrl);
834+
try
835+
{
836+
using (var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(30)))
837+
using (var requestTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, timeoutTokenSource.Token))
838+
{
839+
var baselineToken = await baselineProvider.GetTokenAsync(null, requestTokenSource.Token);
840+
var v2Token = await v2Provider.GetTokenAsync(null, requestTokenSource.Token);
841+
if (baselineToken is VssOAuthAccessToken baselineAccessToken &&
842+
v2Token is VssOAuthAccessToken v2AccessToken &&
843+
!string.IsNullOrEmpty(baselineAccessToken.Value) &&
844+
!string.IsNullOrEmpty(v2AccessToken.Value))
845+
{
846+
var baselineJwt = JsonWebToken.Create(baselineAccessToken.Value);
847+
var baselineClaims = baselineJwt.ExtractClaims();
848+
var v2Jwt = JsonWebToken.Create(v2AccessToken.Value);
849+
var v2Claims = v2Jwt.ExtractClaims();
850+
851+
// Log extracted claims for debugging
852+
Trace.Verbose($"Baseline token expected claims: {string.Join(", ", baselineClaims
853+
.Where(c => expectedClaims.Contains(c.Type.ToLowerInvariant()))
854+
.Select(c => $"{c.Type}:{c.Value}"))}");
855+
Trace.Verbose($"V2 token expected claims: {string.Join(", ", v2Claims
856+
.Where(c => expectedClaims.Contains(c.Type.ToLowerInvariant()))
857+
.Select(c => $"{c.Type}:{c.Value}"))}");
858+
859+
foreach (var claim in expectedClaims)
860+
{
861+
// if baseline has the claim, v2 should have it too with exactly same value.
862+
if (baselineClaims.FirstOrDefault(c => c.Type.ToLowerInvariant() == claim) is Claim baselineClaim &&
863+
!string.IsNullOrEmpty(baselineClaim?.Value))
864+
{
865+
var v2Claim = v2Claims.FirstOrDefault(c => c.Type.ToLowerInvariant() == claim);
866+
if (v2Claim?.Value != baselineClaim.Value)
867+
{
868+
Trace.Info($"Token Claim mismatch between two issuers. Expected: {baselineClaim.Type}:{baselineClaim.Value}. Actual: {v2Claim?.Type ?? "Empty"}:{v2Claim?.Value ?? "Empty"}");
869+
HostContext.DeferAuthMigration(TimeSpan.FromMinutes(60), $"Expected claim {baselineClaim.Type}:{baselineClaim.Value} does not match {v2Claim?.Type ?? "Empty"}:{v2Claim?.Value ?? "Empty"}");
870+
break;
871+
}
872+
}
873+
}
874+
875+
Trace.Info("OAuth token claims check passed.");
876+
}
877+
}
878+
}
879+
catch (Exception ex)
880+
{
881+
Trace.Error("Failed to fetch and check OAuth token claims.");
882+
Trace.Error(ex);
883+
}
884+
}
885+
}
886+
catch (Exception ex)
887+
{
888+
Trace.Error("Failed to check OAuth token claims in background.");
889+
Trace.Error(ex);
890+
}
759891
}
760892

761893
private async Task ReportAuthMigrationTelemetryAsync(CancellationToken token)

0 commit comments

Comments
 (0)