Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/Runner.Listener/Runner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Security.Cryptography;
using System.Security.Claims;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -15,7 +16,9 @@
using GitHub.Runner.Listener.Check;
using GitHub.Runner.Listener.Configuration;
using GitHub.Runner.Sdk;
using GitHub.Services.OAuth;
using GitHub.Services.WebApi;
using GitHub.Services.WebApi.Jwt;
using Pipelines = GitHub.DistributedTask.Pipelines;

namespace GitHub.Runner.Listener
Expand All @@ -35,8 +38,11 @@ public sealed class Runner : RunnerService, IRunner
private readonly ConcurrentQueue<string> _authMigrationTelemetries = new();
private Task _authMigrationTelemetryTask;
private readonly object _authMigrationTelemetryLock = new();
private Task _authMigrationClaimsCheckTask;
private readonly object _authMigrationClaimsCheckLock = new();
private IRunnerServer _runnerServer;
private CancellationTokenSource _authMigrationTelemetryTokenSource = new();
private CancellationTokenSource _authMigrationClaimsCheckTokenSource = new();

// <summary>
// Helps avoid excessive calls to Run Service when encountering non-retriable errors from /acquirejob.
Expand Down Expand Up @@ -329,6 +335,7 @@ public async Task<int> ExecuteCommand(CommandSettings command)
}
finally
{
_authMigrationClaimsCheckTokenSource?.Cancel();
_authMigrationTelemetryTokenSource?.Cancel();
HostContext.AuthMigrationChanged -= HandleAuthMigrationChanged;
_term.CancelKeyPress -= CtrlCHandler;
Expand Down Expand Up @@ -756,6 +763,131 @@ private void HandleAuthMigrationChanged(object sender, AuthMigrationEventArgs e)
_authMigrationTelemetryTask = ReportAuthMigrationTelemetryAsync(_authMigrationTelemetryTokenSource.Token);
}
}

// only start the claims check task once auth migration is changed (enabled or disabled)
lock (_authMigrationClaimsCheckLock)
{
if (_authMigrationClaimsCheckTask == null)
{
_authMigrationClaimsCheckTask = CheckOAuthTokenClaimsAsync(_authMigrationClaimsCheckTokenSource.Token);
}
}
}

private async Task CheckOAuthTokenClaimsAsync(CancellationToken token)
{
string[] expectedClaims =
[
"owner_id",
"runner_id",
"runner_group_id",
"scale_set_id",
"is_ephemeral",
"labels"
];

try
{
var credMgr = HostContext.GetService<ICredentialManager>();
while (!token.IsCancellationRequested)
{
try
{
await HostContext.Delay(TimeSpan.FromMinutes(100), token);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this only delay after the first iteration? Or first successful iteration? Or does it matter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't want the runner to get multiple tokens during start up, try to avoid too much load on the service.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it matter than an intermittent network IO error may cause the logic further below not be skipped for more than multiple 100m delays?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not see any claim diff from the token.
We will use Ring 0 to fully test this out and make sure we never see any failure.

}
catch (TaskCanceledException)
{
// Ignore cancellation
}

if (token.IsCancellationRequested)
{
break;
}

if (!HostContext.AllowAuthMigration)
{
Trace.Info("Skip checking oauth token claims since auth migration is disabled.");
continue;
}

var baselineCred = credMgr.LoadCredentials(allowAuthUrlV2: false);
var authV2Cred = credMgr.LoadCredentials(allowAuthUrlV2: true);

if (!(baselineCred.Federated is VssOAuthCredential baselineVssOAuthCred) ||
!(authV2Cred.Federated is VssOAuthCredential vssOAuthCredV2) ||
baselineVssOAuthCred == null ||
vssOAuthCredV2 == null)
{
Trace.Info("Skip checking oauth token claims for non-oauth credentials");
continue;
}

if (string.Equals(baselineVssOAuthCred.AuthorizationUrl.AbsoluteUri, vssOAuthCredV2.AuthorizationUrl.AbsoluteUri, StringComparison.OrdinalIgnoreCase))
{
Trace.Info("Skip checking oauth token claims for same authorization url");
continue;
}

var baselineProvider = baselineVssOAuthCred.GetTokenProvider(baselineVssOAuthCred.AuthorizationUrl);
var v2Provider = vssOAuthCredV2.GetTokenProvider(vssOAuthCredV2.AuthorizationUrl);
try
{
using (var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(30)))
using (var requestTokenSource = CancellationTokenSource.CreateLinkedTokenSource(token, timeoutTokenSource.Token))
{
var baselineToken = await baselineProvider.GetTokenAsync(null, requestTokenSource.Token);
var v2Token = await v2Provider.GetTokenAsync(null, requestTokenSource.Token);
if (baselineToken is VssOAuthAccessToken baselineAccessToken &&
v2Token is VssOAuthAccessToken v2AccessToken &&
!string.IsNullOrEmpty(baselineAccessToken.Value) &&
!string.IsNullOrEmpty(v2AccessToken.Value))
{
var baselineJwt = JsonWebToken.Create(baselineAccessToken.Value);
var baselineClaims = baselineJwt.ExtractClaims();
var v2Jwt = JsonWebToken.Create(v2AccessToken.Value);
var v2Claims = v2Jwt.ExtractClaims();

// Log extracted claims for debugging
Trace.Verbose($"Baseline token expected claims: {string.Join(", ", baselineClaims
.Where(c => expectedClaims.Contains(c.Type.ToLowerInvariant()))
.Select(c => $"{c.Type}:{c.Value}"))}");
Trace.Verbose($"V2 token expected claims: {string.Join(", ", v2Claims
.Where(c => expectedClaims.Contains(c.Type.ToLowerInvariant()))
.Select(c => $"{c.Type}:{c.Value}"))}");

foreach (var claim in expectedClaims)
{
// if baseline has the claim, v2 should have it too with exactly same value.
if (baselineClaims.FirstOrDefault(c => c.Type.ToLowerInvariant() == claim) is Claim baselineClaim &&
!string.IsNullOrEmpty(baselineClaim?.Value))
{
var v2Claim = v2Claims.FirstOrDefault(c => c.Type.ToLowerInvariant() == claim);
if (v2Claim?.Value != baselineClaim.Value)
{
Trace.Info($"Token Claim mismatch between two issuers. Expected: {baselineClaim.Type}:{baselineClaim.Value}. Actual: {v2Claim?.Type ?? "Empty"}:{v2Claim?.Value ?? "Empty"}");
HostContext.DeferAuthMigration(TimeSpan.FromMinutes(60), $"Expected claim {baselineClaim.Type}:{baselineClaim.Value} does not match {v2Claim?.Type ?? "Empty"}:{v2Claim?.Value ?? "Empty"}");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this trigger any telemetry? Wondering how we know whether this is happening.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed the while loop sleeps for 100m yet this for loop defers 60m. Is that discrepancy OK?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any DeferAuthMigration() should publish telemetry back to service.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's on purposely to not lines up so different things won't trigger at the same time.

break;
}
}
}

Trace.Info("OAuth token claims check passed.");
}
}
}
catch (Exception ex)
{
Trace.Error("Failed to fetch and check OAuth token claims.");
Trace.Error(ex);
}
}
}
catch (Exception ex)
{
Trace.Error("Failed to check OAuth token claims in background.");
Trace.Error(ex);
}
}

private async Task ReportAuthMigrationTelemetryAsync(CancellationToken token)
Expand Down