From 009d1c3adaf00b5b09b3886cad5faa5fa64ffe2e Mon Sep 17 00:00:00 2001 From: Jagpreet Singh Tamber Date: Mon, 2 Jun 2025 16:51:30 -0400 Subject: [PATCH 1/2] fix: #23100 Change workloadidentity token cache expiry based on token expiry. (#23133) Signed-off-by: Jagpreet Singh Tamber Signed-off-by: Alexandre Gaudreault --- util/git/client_test.go | 3 +- util/git/creds.go | 9 +- util/git/creds_test.go | 53 ++++++- util/helm/client_test.go | 3 +- util/helm/creds.go | 30 +++- util/helm/creds_test.go | 132 +++++++++++++++++- util/workloadidentity/mocks/TokenProvider.go | 87 ++++++++---- util/workloadidentity/workloadidentity.go | 22 ++- .../workloadidentity/workloadidentity_test.go | 60 +++++++- 9 files changed, 352 insertions(+), 47 deletions(-) diff --git a/util/git/client_test.go b/util/git/client_test.go index e74eb62fe17ed..74b077ee2b0d3 100644 --- a/util/git/client_test.go +++ b/util/git/client_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/argoproj/argo-cd/v3/util/workloadidentity" "github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks" ) @@ -847,7 +848,7 @@ func Test_nativeGitClient_CommitAndPush(t *testing.T) { func Test_newAuth_AzureWorkloadIdentity(t *testing.T) { tokenprovider := new(mocks.TokenProvider) - tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil) + tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) creds := AzureWorkloadIdentityCreds{store: NoopCredsStore{}, tokenProvider: tokenprovider} diff --git a/util/git/creds.go b/util/git/creds.go index 3d166162b2c8c..062a364d7cbed 100644 --- a/util/git/creds.go +++ b/util/git/creds.go @@ -735,7 +735,7 @@ func (creds AzureWorkloadIdentityCreds) getAccessToken(scope string) (string, er t, found := azureTokenCache.Get(key) if found { - return t.(string), nil + return t.(*workloadidentity.Token).AccessToken, nil } token, err := creds.tokenProvider.GetToken(scope) @@ -743,8 +743,11 @@ func (creds AzureWorkloadIdentityCreds) getAccessToken(scope string) (string, er return "", fmt.Errorf("failed to get Azure access token: %w", err) } - azureTokenCache.Set(key, token, 2*time.Hour) - return token, nil + cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(token.ExpiresOn) + if cacheExpiry > 0 { + azureTokenCache.Set(key, token, cacheExpiry) + } + return token.AccessToken, nil } func (creds AzureWorkloadIdentityCreds) GetAzureDevOpsAccessToken() (string, error) { diff --git a/util/git/creds_test.go b/util/git/creds_test.go index 97b8337128c3c..1faecfe8dfc04 100644 --- a/util/git/creds_test.go +++ b/util/git/creds_test.go @@ -8,8 +8,10 @@ import ( "regexp" "strings" "testing" + "time" "github.com/google/uuid" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -412,9 +414,10 @@ func TestGoogleCloudCreds_Environ_cleanup(t *testing.T) { } func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) { + resetAzureTokenCache() store := &memoryCredsStore{creds: make(map[string]cred)} workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil) + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil) creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock} _, _, err := creds.Environ() require.NoError(t, err) @@ -427,9 +430,10 @@ func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) { } func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) { + resetAzureTokenCache() store := &memoryCredsStore{creds: make(map[string]cred)} workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil) + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil) creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock} closer, _, err := creds.Environ() require.NoError(t, err) @@ -439,9 +443,10 @@ func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) { } func TestAzureWorkloadIdentityCreds_GetUserInfo(t *testing.T) { + resetAzureTokenCache() store := &memoryCredsStore{creds: make(map[string]cred)} workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil) + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil) creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock} user, email, err := creds.GetUserInfo(t.Context()) @@ -456,3 +461,45 @@ func TestGetHelmCredsShouldReturnHelmCredsIfAzureWorkloadIdentityNotSpecified(t _, ok := creds.(AzureWorkloadIdentityCreds) require.Truef(t, ok, "expected HelmCreds but got %T", creds) } + +func TestAzureWorkloadIdentityCreds_FetchNewTokenIfExistingIsExpired(t *testing.T) { + resetAzureTokenCache() + store := &memoryCredsStore{creds: make(map[string]cred)} + workloadIdentityMock := new(mocks.TokenProvider) + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId). + Return(&workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(time.Minute)}, nil).Once() + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId). + Return(&workloadidentity.Token{AccessToken: "secondToken"}, nil).Once() + creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock} + token, err := creds.GetAzureDevOpsAccessToken() + require.NoError(t, err) + + assert.Equal(t, "firstToken", token) + time.Sleep(5 * time.Second) + token, err = creds.GetAzureDevOpsAccessToken() + require.NoError(t, err) + assert.Equal(t, "secondToken", token) +} + +func TestAzureWorkloadIdentityCreds_ReuseTokenIfExistingIsNotExpired(t *testing.T) { + resetAzureTokenCache() + store := &memoryCredsStore{creds: make(map[string]cred)} + workloadIdentityMock := new(mocks.TokenProvider) + firstToken := &workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(6 * time.Minute)} + secondToken := &workloadidentity.Token{AccessToken: "secondToken"} + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(firstToken, nil).Once() + workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(secondToken, nil).Once() + creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock} + token, err := creds.GetAzureDevOpsAccessToken() + require.NoError(t, err) + + assert.Equal(t, "firstToken", token) + time.Sleep(5 * time.Second) + token, err = creds.GetAzureDevOpsAccessToken() + require.NoError(t, err) + assert.Equal(t, "firstToken", token) +} + +func resetAzureTokenCache() { + azureTokenCache = gocache.New(gocache.NoExpiration, 0) +} diff --git a/util/helm/client_test.go b/util/helm/client_test.go index 52fa3179b4821..132e6d0a92868 100644 --- a/util/helm/client_test.go +++ b/util/helm/client_test.go @@ -18,6 +18,7 @@ import ( "gopkg.in/yaml.v2" "github.com/argoproj/argo-cd/v3/util/io" + "github.com/argoproj/argo-cd/v3/util/workloadidentity" "github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks" ) @@ -300,7 +301,7 @@ func TestGetTagsFromURLPrivateRepoWithAzureWorkloadIdentityAuthentication(t *tes } workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf("called %s", r.URL.Path) diff --git a/util/helm/creds.go b/util/helm/creds.go index 3b6001f757596..c82f53a25f243 100644 --- a/util/helm/creds.go +++ b/util/helm/creds.go @@ -11,7 +11,9 @@ import ( "strings" "time" + "github.com/golang-jwt/jwt/v5" gocache "github.com/patrickmn/go-cache" + log "github.com/sirupsen/logrus" argoutils "github.com/argoproj/argo-cd/v3/util" "github.com/argoproj/argo-cd/v3/util/env" @@ -146,11 +148,33 @@ func (creds AzureWorkloadIdentityCreds) GetAccessToken() (string, error) { return "", fmt.Errorf("failed to get Azure access token after challenge: %w", err) } - // Access token has a lifetime of 3 hours - storeAzureToken(key, token, 2*time.Hour) + tokenExpiry, err := getJWTExpiry(token) + if err != nil { + log.Warnf("failed to get token expiry from JWT: %v, using current time as fallback", err) + tokenExpiry = time.Now() + } + + cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry) + if cacheExpiry > 0 { + storeAzureToken(key, token, cacheExpiry) + } return token, nil } +func getJWTExpiry(token string) (time.Time, error) { + parser := jwt.NewParser() + claims := jwt.MapClaims{} + _, _, err := parser.ParseUnverified(token, claims) + if err != nil { + return time.Time{}, fmt.Errorf("failed to parse JWT: %w", err) + } + exp, err := claims.GetExpirationTime() + if err != nil { + return time.Time{}, fmt.Errorf("'exp' claim not found or invalid in token: %w", err) + } + return time.UnixMilli(exp.UnixMilli()), nil +} + func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams map[string]string) (string, error) { realm := tokenParams["realm"] service := tokenParams["service"] @@ -177,7 +201,7 @@ func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams formValues := url.Values{} formValues.Add("grant_type", "access_token") formValues.Add("service", service) - formValues.Add("access_token", armAccessToken) + formValues.Add("access_token", armAccessToken.AccessToken) resp, err := client.PostForm(refreshTokenURL, formValues) if err != nil { diff --git a/util/helm/creds_test.go b/util/helm/creds_test.go index 90a232643e970..999f886ed04a9 100644 --- a/util/helm/creds_test.go +++ b/util/helm/creds_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -79,7 +81,7 @@ func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) { defer mockServer.Close() workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) // Retrieve the token from the cache @@ -191,7 +193,7 @@ func TestGetAccessTokenAfterChallenge_Success(t *testing.T) { defer mockServer.Close() workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) tokenParams := map[string]string{ @@ -216,7 +218,7 @@ func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) { // Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) tokenParams := map[string]string{ @@ -241,7 +243,7 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) { // Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper workloadIdentityMock := new(mocks.TokenProvider) - workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) tokenParams := map[string]string{ @@ -253,3 +255,125 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) { require.ErrorContains(t, err, "failed to unmarshal response body") assert.Empty(t, refreshToken) } + +// Helper to generate a mock JWT token with a given expiry time +func generateMockJWT(expiry time.Time) (string, error) { + claims := jwt.MapClaims{ + "exp": expiry.Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + // Use a dummy secret for signing + return token.SignedString([]byte("dummy-secret")) +} + +func TestGetAccessToken_FetchNewTokenIfExistingIsExpired(t *testing.T) { + resetAzureTokenCache() + accessToken1, _ := generateMockJWT(time.Now().Add(1 * time.Minute)) + accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute)) + + mockServerURL := "" + mockedServerURL := func() string { + return mockServerURL + } + + callCount := 0 + mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v2/": + assert.Equal(t, "/v2/", r.URL.Path) + w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:])) + w.WriteHeader(http.StatusUnauthorized) + case "/oauth2/exchange": + assert.Equal(t, "/oauth2/exchange", r.URL.Path) + var response string + switch callCount { + case 0: + response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1) + case 1: + response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2) + default: + response = `{"refresh_token": "defaultToken"}` + } + callCount++ + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(response)) + require.NoError(t, err) + default: + http.NotFound(w, r) + } + })) + defer mockServer.Close() + mockServerURL = mockServer.URL + + workloadIdentityMock := new(mocks.TokenProvider) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) + creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) + + refreshToken, err := creds.GetAccessToken() + require.NoError(t, err) + assert.Equal(t, accessToken1, refreshToken) + + time.Sleep(5 * time.Second) // Wait for the token to expire + + refreshToken, err = creds.GetAccessToken() + require.NoError(t, err) + assert.Equal(t, accessToken2, refreshToken) +} + +func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired(t *testing.T) { + resetAzureTokenCache() + accessToken1, _ := generateMockJWT(time.Now().Add(6 * time.Minute)) + accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute)) + + mockServerURL := "" + mockedServerURL := func() string { + return mockServerURL + } + + callCount := 0 + mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v2/": + assert.Equal(t, "/v2/", r.URL.Path) + w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:])) + w.WriteHeader(http.StatusUnauthorized) + case "/oauth2/exchange": + assert.Equal(t, "/oauth2/exchange", r.URL.Path) + var response string + switch callCount { + case 0: + response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1) + case 1: + response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2) + default: + response = `{"refresh_token": "defaultToken"}` + } + callCount++ + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(response)) + require.NoError(t, err) + default: + http.NotFound(w, r) + } + })) + defer mockServer.Close() + mockServerURL = mockServer.URL + + workloadIdentityMock := new(mocks.TokenProvider) + workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) + creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) + + refreshToken, err := creds.GetAccessToken() + require.NoError(t, err) + assert.Equal(t, accessToken1, refreshToken) + + time.Sleep(5 * time.Second) // Wait for the token to expire + + refreshToken, err = creds.GetAccessToken() + require.NoError(t, err) + assert.Equal(t, accessToken1, refreshToken) +} + +func resetAzureTokenCache() { + azureTokenCache = gocache.New(gocache.NoExpiration, 0) +} diff --git a/util/workloadidentity/mocks/TokenProvider.go b/util/workloadidentity/mocks/TokenProvider.go index 883e76ade56ce..abfd993d34cba 100644 --- a/util/workloadidentity/mocks/TokenProvider.go +++ b/util/workloadidentity/mocks/TokenProvider.go @@ -2,51 +2,90 @@ package mocks -import mock "github.com/stretchr/testify/mock" +import ( + "github.com/argoproj/argo-cd/v3/util/workloadidentity" + mock "github.com/stretchr/testify/mock" +) + +// NewTokenProvider creates a new instance of TokenProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenProvider { + mock := &TokenProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} // TokenProvider is an autogenerated mock type for the TokenProvider type type TokenProvider struct { mock.Mock } -// GetToken provides a mock function with given fields: scope -func (_m *TokenProvider) GetToken(scope string) (string, error) { - ret := _m.Called(scope) +type TokenProvider_Expecter struct { + mock *mock.Mock +} + +func (_m *TokenProvider) EXPECT() *TokenProvider_Expecter { + return &TokenProvider_Expecter{mock: &_m.Mock} +} + +// GetToken provides a mock function for the type TokenProvider +func (_mock *TokenProvider) GetToken(scope string) (*workloadidentity.Token, error) { + ret := _mock.Called(scope) if len(ret) == 0 { panic("no return value specified for GetToken") } - var r0 string + var r0 *workloadidentity.Token var r1 error - if rf, ok := ret.Get(0).(func(string) (string, error)); ok { - return rf(scope) + if returnFunc, ok := ret.Get(0).(func(string) (*workloadidentity.Token, error)); ok { + return returnFunc(scope) } - if rf, ok := ret.Get(0).(func(string) string); ok { - r0 = rf(scope) + if returnFunc, ok := ret.Get(0).(func(string) *workloadidentity.Token); ok { + r0 = returnFunc(scope) } else { - r0 = ret.Get(0).(string) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*workloadidentity.Token) + } } - - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(scope) + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(scope) } else { r1 = ret.Error(1) } - return r0, r1 } -// NewTokenProvider creates a new instance of TokenProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewTokenProvider(t interface { - mock.TestingT - Cleanup(func()) -}) *TokenProvider { - mock := &TokenProvider{} - mock.Mock.Test(t) +// TokenProvider_GetToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetToken' +type TokenProvider_GetToken_Call struct { + *mock.Call +} - t.Cleanup(func() { mock.AssertExpectations(t) }) +// GetToken is a helper method to define mock.On call +// - scope +func (_e *TokenProvider_Expecter) GetToken(scope interface{}) *TokenProvider_GetToken_Call { + return &TokenProvider_GetToken_Call{Call: _e.mock.On("GetToken", scope)} +} - return mock +func (_c *TokenProvider_GetToken_Call) Run(run func(scope string)) *TokenProvider_GetToken_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *TokenProvider_GetToken_Call) Return(token *workloadidentity.Token, err error) *TokenProvider_GetToken_Call { + _c.Call.Return(token, err) + return _c +} + +func (_c *TokenProvider_GetToken_Call) RunAndReturn(run func(scope string) (*workloadidentity.Token, error)) *TokenProvider_GetToken_Call { + _c.Call.Return(run) + return _c } diff --git a/util/workloadidentity/workloadidentity.go b/util/workloadidentity/workloadidentity.go index 24aef2ffa5e64..08482b3ba6bdc 100644 --- a/util/workloadidentity/workloadidentity.go +++ b/util/workloadidentity/workloadidentity.go @@ -2,6 +2,7 @@ package workloadidentity import ( "context" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -12,8 +13,13 @@ const ( EmptyGuid = "00000000-0000-0000-0000-000000000000" //nolint:revive //FIXME(var-naming) ) +type Token struct { + AccessToken string + ExpiresOn time.Time +} + type TokenProvider interface { - GetToken(scope string) (string, error) + GetToken(scope string) (*Token, error) } type WorkloadIdentityTokenProvider struct { @@ -29,17 +35,23 @@ func NewWorkloadIdentityTokenProvider() TokenProvider { return WorkloadIdentityTokenProvider{tokenCredential: cred} } -func (c WorkloadIdentityTokenProvider) GetToken(scope string) (string, error) { +func (c WorkloadIdentityTokenProvider) GetToken(scope string) (*Token, error) { if initError != nil { - return "", initError + return nil, initError } token, err := c.tokenCredential.GetToken(context.Background(), policy.TokenRequestOptions{ Scopes: []string{scope}, }) if err != nil { - return "", err + return nil, err } - return token.Token, nil + return &Token{AccessToken: token.Token, ExpiresOn: token.ExpiresOn}, nil +} + +func CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry time.Time) time.Duration { + // Calculate the cache expiry as 5 minutes before the token expires + cacheExpiry := time.Until(tokenExpiry) - time.Minute*5 + return cacheExpiry } diff --git a/util/workloadidentity/workloadidentity_test.go b/util/workloadidentity/workloadidentity_test.go index 6e16f8d82926d..1dc416b8e9f12 100644 --- a/util/workloadidentity/workloadidentity_test.go +++ b/util/workloadidentity/workloadidentity_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -37,7 +38,7 @@ func TestGetToken_Success(t *testing.T) { token, err := provider.GetToken(scope) require.NoError(t, err, "Expected no error from GetToken") - assert.Equal(t, "mocked_token", token, "Expected token to match") + assert.Equal(t, "mocked_token", token.AccessToken, "Expected token to match") } func TestGetToken_Failure(t *testing.T) { @@ -47,7 +48,7 @@ func TestGetToken_Failure(t *testing.T) { token, err := provider.GetToken(scope) require.Error(t, err, "Expected error from GetToken") - assert.Empty(t, token, "Expected token to be empty on error") + assert.Nil(t, token, "Expected token to be empty on error") } func TestGetToken_InitError(t *testing.T) { @@ -56,5 +57,58 @@ func TestGetToken_InitError(t *testing.T) { token, err := provider.GetToken("https://management.core.windows.net/.default") require.Error(t, err, "Expected error from GetToken due to initialization error") - assert.Empty(t, token, "Expected token to be empty on initialization error") + assert.Nil(t, token, "Expected token to be empty on initialization error") +} + +func TestCalculateCacheExpiryBasedOnTokenExpiry(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + expiry time.Time + expected time.Duration + delta float64 + }{ + { + name: "Future expiry (10min ahead)", + expiry: now.Add(10 * time.Minute), + expected: 5 * time.Minute, + delta: 10, // allow 10s difference + }, + { + name: "Expiring in 5 minutes", + expiry: now.Add(5 * time.Second), + expected: now.Sub(now.Add(5 * time.Minute)), + delta: 10, // allow 10s difference + }, + { + name: "Expires soon (4min ahead)", + expiry: now.Add(4 * time.Minute), + expected: now.Sub(now.Add(1 * time.Minute)), + delta: 10, // allow 10s difference + }, + { + name: "Just expired (1s ago)", + expiry: now.Add(-1 * time.Second), + expected: now.Sub(now.Add(5 * time.Minute)), + delta: 10, // allow 10s difference + }, + { + name: "Already expired (1m ago)", + expiry: now.Add(-1 * time.Minute), + expected: now.Sub(now.Add(6 * time.Minute)), + delta: 10, // allow 10s difference + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := CalculateCacheExpiryBasedOnTokenExpiry(tt.expiry) + if tt.delta > 0 { + assert.InDelta(t, tt.expected.Seconds(), actual.Seconds(), tt.delta) + } else { + assert.Equal(t, tt.expected, actual) + } + }) + } } From 3ee5c883924601d5e191657741f85fc647ffafd5 Mon Sep 17 00:00:00 2001 From: Alexandre Gaudreault Date: Wed, 4 Jun 2025 12:36:12 -0400 Subject: [PATCH 2/2] fix mock Signed-off-by: Alexandre Gaudreault --- util/workloadidentity/mocks/TokenProvider.go | 78 ++++++-------------- 1 file changed, 22 insertions(+), 56 deletions(-) diff --git a/util/workloadidentity/mocks/TokenProvider.go b/util/workloadidentity/mocks/TokenProvider.go index abfd993d34cba..9e333a44cbade 100644 --- a/util/workloadidentity/mocks/TokenProvider.go +++ b/util/workloadidentity/mocks/TokenProvider.go @@ -3,40 +3,18 @@ package mocks import ( - "github.com/argoproj/argo-cd/v3/util/workloadidentity" + workloadidentity "github.com/argoproj/argo-cd/v3/util/workloadidentity" mock "github.com/stretchr/testify/mock" ) -// NewTokenProvider creates a new instance of TokenProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewTokenProvider(t interface { - mock.TestingT - Cleanup(func()) -}) *TokenProvider { - mock := &TokenProvider{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - // TokenProvider is an autogenerated mock type for the TokenProvider type type TokenProvider struct { mock.Mock } -type TokenProvider_Expecter struct { - mock *mock.Mock -} - -func (_m *TokenProvider) EXPECT() *TokenProvider_Expecter { - return &TokenProvider_Expecter{mock: &_m.Mock} -} - -// GetToken provides a mock function for the type TokenProvider -func (_mock *TokenProvider) GetToken(scope string) (*workloadidentity.Token, error) { - ret := _mock.Called(scope) +// GetToken provides a mock function with given fields: scope +func (_m *TokenProvider) GetToken(scope string) (*workloadidentity.Token, error) { + ret := _m.Called(scope) if len(ret) == 0 { panic("no return value specified for GetToken") @@ -44,48 +22,36 @@ func (_mock *TokenProvider) GetToken(scope string) (*workloadidentity.Token, err var r0 *workloadidentity.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (*workloadidentity.Token, error)); ok { - return returnFunc(scope) + if rf, ok := ret.Get(0).(func(string) (*workloadidentity.Token, error)); ok { + return rf(scope) } - if returnFunc, ok := ret.Get(0).(func(string) *workloadidentity.Token); ok { - r0 = returnFunc(scope) + if rf, ok := ret.Get(0).(func(string) *workloadidentity.Token); ok { + r0 = rf(scope) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*workloadidentity.Token) } } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(scope) + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(scope) } else { r1 = ret.Error(1) } - return r0, r1 -} - -// TokenProvider_GetToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetToken' -type TokenProvider_GetToken_Call struct { - *mock.Call -} -// GetToken is a helper method to define mock.On call -// - scope -func (_e *TokenProvider_Expecter) GetToken(scope interface{}) *TokenProvider_GetToken_Call { - return &TokenProvider_GetToken_Call{Call: _e.mock.On("GetToken", scope)} + return r0, r1 } -func (_c *TokenProvider_GetToken_Call) Run(run func(scope string)) *TokenProvider_GetToken_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) - }) - return _c -} +// NewTokenProvider creates a new instance of TokenProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenProvider { + mock := &TokenProvider{} + mock.Mock.Test(t) -func (_c *TokenProvider_GetToken_Call) Return(token *workloadidentity.Token, err error) *TokenProvider_GetToken_Call { - _c.Call.Return(token, err) - return _c -} + t.Cleanup(func() { mock.AssertExpectations(t) }) -func (_c *TokenProvider_GetToken_Call) RunAndReturn(run func(scope string) (*workloadidentity.Token, error)) *TokenProvider_GetToken_Call { - _c.Call.Return(run) - return _c + return mock }