Skip to content

Commit f72ac78

Browse files
fix: Change workloadidentity token cache expiry based on token expiry (#23100) (#23264)
Signed-off-by: Jagpreet Singh Tamber <[email protected]> Signed-off-by: Alexandre Gaudreault <[email protected]> Co-authored-by: Jagpreet Singh Tamber <[email protected]>
1 parent 001848e commit f72ac78

File tree

9 files changed

+300
-29
lines changed

9 files changed

+300
-29
lines changed

util/git/client_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/stretchr/testify/assert"
1919
"github.com/stretchr/testify/require"
2020

21+
"github.com/argoproj/argo-cd/v3/util/workloadidentity"
2122
"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
2223
)
2324

@@ -847,7 +848,7 @@ func Test_nativeGitClient_CommitAndPush(t *testing.T) {
847848

848849
func Test_newAuth_AzureWorkloadIdentity(t *testing.T) {
849850
tokenprovider := new(mocks.TokenProvider)
850-
tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
851+
tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
851852

852853
creds := AzureWorkloadIdentityCreds{store: NoopCredsStore{}, tokenProvider: tokenprovider}
853854

util/git/creds.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -735,16 +735,19 @@ func (creds AzureWorkloadIdentityCreds) getAccessToken(scope string) (string, er
735735

736736
t, found := azureTokenCache.Get(key)
737737
if found {
738-
return t.(string), nil
738+
return t.(*workloadidentity.Token).AccessToken, nil
739739
}
740740

741741
token, err := creds.tokenProvider.GetToken(scope)
742742
if err != nil {
743743
return "", fmt.Errorf("failed to get Azure access token: %w", err)
744744
}
745745

746-
azureTokenCache.Set(key, token, 2*time.Hour)
747-
return token, nil
746+
cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(token.ExpiresOn)
747+
if cacheExpiry > 0 {
748+
azureTokenCache.Set(key, token, cacheExpiry)
749+
}
750+
return token.AccessToken, nil
748751
}
749752

750753
func (creds AzureWorkloadIdentityCreds) GetAzureDevOpsAccessToken() (string, error) {

util/git/creds_test.go

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88
"regexp"
99
"strings"
1010
"testing"
11+
"time"
1112

1213
"github.com/google/uuid"
14+
gocache "github.com/patrickmn/go-cache"
1315
"github.com/stretchr/testify/assert"
1416
"github.com/stretchr/testify/require"
1517
"golang.org/x/oauth2"
@@ -412,9 +414,10 @@ func TestGoogleCloudCreds_Environ_cleanup(t *testing.T) {
412414
}
413415

414416
func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) {
417+
resetAzureTokenCache()
415418
store := &memoryCredsStore{creds: make(map[string]cred)}
416419
workloadIdentityMock := new(mocks.TokenProvider)
417-
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
420+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
418421
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
419422
_, _, err := creds.Environ()
420423
require.NoError(t, err)
@@ -427,9 +430,10 @@ func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) {
427430
}
428431

429432
func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) {
433+
resetAzureTokenCache()
430434
store := &memoryCredsStore{creds: make(map[string]cred)}
431435
workloadIdentityMock := new(mocks.TokenProvider)
432-
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
436+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
433437
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
434438
closer, _, err := creds.Environ()
435439
require.NoError(t, err)
@@ -439,9 +443,10 @@ func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) {
439443
}
440444

441445
func TestAzureWorkloadIdentityCreds_GetUserInfo(t *testing.T) {
446+
resetAzureTokenCache()
442447
store := &memoryCredsStore{creds: make(map[string]cred)}
443448
workloadIdentityMock := new(mocks.TokenProvider)
444-
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
449+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
445450
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
446451

447452
user, email, err := creds.GetUserInfo(t.Context())
@@ -456,3 +461,45 @@ func TestGetHelmCredsShouldReturnHelmCredsIfAzureWorkloadIdentityNotSpecified(t
456461
_, ok := creds.(AzureWorkloadIdentityCreds)
457462
require.Truef(t, ok, "expected HelmCreds but got %T", creds)
458463
}
464+
465+
func TestAzureWorkloadIdentityCreds_FetchNewTokenIfExistingIsExpired(t *testing.T) {
466+
resetAzureTokenCache()
467+
store := &memoryCredsStore{creds: make(map[string]cred)}
468+
workloadIdentityMock := new(mocks.TokenProvider)
469+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).
470+
Return(&workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(time.Minute)}, nil).Once()
471+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).
472+
Return(&workloadidentity.Token{AccessToken: "secondToken"}, nil).Once()
473+
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
474+
token, err := creds.GetAzureDevOpsAccessToken()
475+
require.NoError(t, err)
476+
477+
assert.Equal(t, "firstToken", token)
478+
time.Sleep(5 * time.Second)
479+
token, err = creds.GetAzureDevOpsAccessToken()
480+
require.NoError(t, err)
481+
assert.Equal(t, "secondToken", token)
482+
}
483+
484+
func TestAzureWorkloadIdentityCreds_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
485+
resetAzureTokenCache()
486+
store := &memoryCredsStore{creds: make(map[string]cred)}
487+
workloadIdentityMock := new(mocks.TokenProvider)
488+
firstToken := &workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(6 * time.Minute)}
489+
secondToken := &workloadidentity.Token{AccessToken: "secondToken"}
490+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(firstToken, nil).Once()
491+
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(secondToken, nil).Once()
492+
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
493+
token, err := creds.GetAzureDevOpsAccessToken()
494+
require.NoError(t, err)
495+
496+
assert.Equal(t, "firstToken", token)
497+
time.Sleep(5 * time.Second)
498+
token, err = creds.GetAzureDevOpsAccessToken()
499+
require.NoError(t, err)
500+
assert.Equal(t, "firstToken", token)
501+
}
502+
503+
func resetAzureTokenCache() {
504+
azureTokenCache = gocache.New(gocache.NoExpiration, 0)
505+
}

util/helm/client_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"gopkg.in/yaml.v2"
1919

2020
"github.com/argoproj/argo-cd/v3/util/io"
21+
"github.com/argoproj/argo-cd/v3/util/workloadidentity"
2122
"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
2223
)
2324

@@ -300,7 +301,7 @@ func TestGetTagsFromURLPrivateRepoWithAzureWorkloadIdentityAuthentication(t *tes
300301
}
301302

302303
workloadIdentityMock := new(mocks.TokenProvider)
303-
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
304+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
304305

305306
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
306307
t.Logf("called %s", r.URL.Path)

util/helm/creds.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"strings"
1212
"time"
1313

14+
"github.com/golang-jwt/jwt/v5"
1415
gocache "github.com/patrickmn/go-cache"
16+
log "github.com/sirupsen/logrus"
1517

1618
argoutils "github.com/argoproj/argo-cd/v3/util"
1719
"github.com/argoproj/argo-cd/v3/util/env"
@@ -146,11 +148,33 @@ func (creds AzureWorkloadIdentityCreds) GetAccessToken() (string, error) {
146148
return "", fmt.Errorf("failed to get Azure access token after challenge: %w", err)
147149
}
148150

149-
// Access token has a lifetime of 3 hours
150-
storeAzureToken(key, token, 2*time.Hour)
151+
tokenExpiry, err := getJWTExpiry(token)
152+
if err != nil {
153+
log.Warnf("failed to get token expiry from JWT: %v, using current time as fallback", err)
154+
tokenExpiry = time.Now()
155+
}
156+
157+
cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry)
158+
if cacheExpiry > 0 {
159+
storeAzureToken(key, token, cacheExpiry)
160+
}
151161
return token, nil
152162
}
153163

164+
func getJWTExpiry(token string) (time.Time, error) {
165+
parser := jwt.NewParser()
166+
claims := jwt.MapClaims{}
167+
_, _, err := parser.ParseUnverified(token, claims)
168+
if err != nil {
169+
return time.Time{}, fmt.Errorf("failed to parse JWT: %w", err)
170+
}
171+
exp, err := claims.GetExpirationTime()
172+
if err != nil {
173+
return time.Time{}, fmt.Errorf("'exp' claim not found or invalid in token: %w", err)
174+
}
175+
return time.UnixMilli(exp.UnixMilli()), nil
176+
}
177+
154178
func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams map[string]string) (string, error) {
155179
realm := tokenParams["realm"]
156180
service := tokenParams["service"]
@@ -177,7 +201,7 @@ func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams
177201
formValues := url.Values{}
178202
formValues.Add("grant_type", "access_token")
179203
formValues.Add("service", service)
180-
formValues.Add("access_token", armAccessToken)
204+
formValues.Add("access_token", armAccessToken.AccessToken)
181205

182206
resp, err := client.PostForm(refreshTokenURL, formValues)
183207
if err != nil {

util/helm/creds_test.go

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"testing"
88
"time"
99

10+
"github.com/golang-jwt/jwt/v5"
11+
gocache "github.com/patrickmn/go-cache"
1012
"github.com/stretchr/testify/assert"
1113
"github.com/stretchr/testify/require"
1214

@@ -79,7 +81,7 @@ func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) {
7981
defer mockServer.Close()
8082

8183
workloadIdentityMock := new(mocks.TokenProvider)
82-
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
84+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
8385
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
8486

8587
// Retrieve the token from the cache
@@ -191,7 +193,7 @@ func TestGetAccessTokenAfterChallenge_Success(t *testing.T) {
191193
defer mockServer.Close()
192194

193195
workloadIdentityMock := new(mocks.TokenProvider)
194-
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
196+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
195197
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
196198

197199
tokenParams := map[string]string{
@@ -216,7 +218,7 @@ func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) {
216218

217219
// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
218220
workloadIdentityMock := new(mocks.TokenProvider)
219-
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
221+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
220222
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
221223

222224
tokenParams := map[string]string{
@@ -241,7 +243,7 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
241243

242244
// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
243245
workloadIdentityMock := new(mocks.TokenProvider)
244-
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
246+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
245247
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
246248

247249
tokenParams := map[string]string{
@@ -253,3 +255,125 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
253255
require.ErrorContains(t, err, "failed to unmarshal response body")
254256
assert.Empty(t, refreshToken)
255257
}
258+
259+
// Helper to generate a mock JWT token with a given expiry time
260+
func generateMockJWT(expiry time.Time) (string, error) {
261+
claims := jwt.MapClaims{
262+
"exp": expiry.Unix(),
263+
}
264+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
265+
// Use a dummy secret for signing
266+
return token.SignedString([]byte("dummy-secret"))
267+
}
268+
269+
func TestGetAccessToken_FetchNewTokenIfExistingIsExpired(t *testing.T) {
270+
resetAzureTokenCache()
271+
accessToken1, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
272+
accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
273+
274+
mockServerURL := ""
275+
mockedServerURL := func() string {
276+
return mockServerURL
277+
}
278+
279+
callCount := 0
280+
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
281+
switch r.URL.Path {
282+
case "/v2/":
283+
assert.Equal(t, "/v2/", r.URL.Path)
284+
w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:]))
285+
w.WriteHeader(http.StatusUnauthorized)
286+
case "/oauth2/exchange":
287+
assert.Equal(t, "/oauth2/exchange", r.URL.Path)
288+
var response string
289+
switch callCount {
290+
case 0:
291+
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1)
292+
case 1:
293+
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2)
294+
default:
295+
response = `{"refresh_token": "defaultToken"}`
296+
}
297+
callCount++
298+
w.WriteHeader(http.StatusOK)
299+
_, err := w.Write([]byte(response))
300+
require.NoError(t, err)
301+
default:
302+
http.NotFound(w, r)
303+
}
304+
}))
305+
defer mockServer.Close()
306+
mockServerURL = mockServer.URL
307+
308+
workloadIdentityMock := new(mocks.TokenProvider)
309+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
310+
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
311+
312+
refreshToken, err := creds.GetAccessToken()
313+
require.NoError(t, err)
314+
assert.Equal(t, accessToken1, refreshToken)
315+
316+
time.Sleep(5 * time.Second) // Wait for the token to expire
317+
318+
refreshToken, err = creds.GetAccessToken()
319+
require.NoError(t, err)
320+
assert.Equal(t, accessToken2, refreshToken)
321+
}
322+
323+
func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
324+
resetAzureTokenCache()
325+
accessToken1, _ := generateMockJWT(time.Now().Add(6 * time.Minute))
326+
accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
327+
328+
mockServerURL := ""
329+
mockedServerURL := func() string {
330+
return mockServerURL
331+
}
332+
333+
callCount := 0
334+
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
335+
switch r.URL.Path {
336+
case "/v2/":
337+
assert.Equal(t, "/v2/", r.URL.Path)
338+
w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:]))
339+
w.WriteHeader(http.StatusUnauthorized)
340+
case "/oauth2/exchange":
341+
assert.Equal(t, "/oauth2/exchange", r.URL.Path)
342+
var response string
343+
switch callCount {
344+
case 0:
345+
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1)
346+
case 1:
347+
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2)
348+
default:
349+
response = `{"refresh_token": "defaultToken"}`
350+
}
351+
callCount++
352+
w.WriteHeader(http.StatusOK)
353+
_, err := w.Write([]byte(response))
354+
require.NoError(t, err)
355+
default:
356+
http.NotFound(w, r)
357+
}
358+
}))
359+
defer mockServer.Close()
360+
mockServerURL = mockServer.URL
361+
362+
workloadIdentityMock := new(mocks.TokenProvider)
363+
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
364+
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
365+
366+
refreshToken, err := creds.GetAccessToken()
367+
require.NoError(t, err)
368+
assert.Equal(t, accessToken1, refreshToken)
369+
370+
time.Sleep(5 * time.Second) // Wait for the token to expire
371+
372+
refreshToken, err = creds.GetAccessToken()
373+
require.NoError(t, err)
374+
assert.Equal(t, accessToken1, refreshToken)
375+
}
376+
377+
func resetAzureTokenCache() {
378+
azureTokenCache = gocache.New(gocache.NoExpiration, 0)
379+
}

0 commit comments

Comments
 (0)