Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion util/git/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/argoproj/argo-cd/v3/util/workloadidentity"
"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
)

Expand Down Expand Up @@ -847,7 +848,7 @@ func Test_nativeGitClient_CommitAndPush(t *testing.T) {

func Test_newAuth_AzureWorkloadIdentity(t *testing.T) {
tokenprovider := new(mocks.TokenProvider)
tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)

creds := AzureWorkloadIdentityCreds{store: NoopCredsStore{}, tokenProvider: tokenprovider}

Expand Down
9 changes: 6 additions & 3 deletions util/git/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,16 +735,19 @@ func (creds AzureWorkloadIdentityCreds) getAccessToken(scope string) (string, er

t, found := azureTokenCache.Get(key)
if found {
return t.(string), nil
return t.(*workloadidentity.Token).AccessToken, nil
}

token, err := creds.tokenProvider.GetToken(scope)
if err != nil {
return "", fmt.Errorf("failed to get Azure access token: %w", err)
}

azureTokenCache.Set(key, token, 2*time.Hour)
return token, nil
cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(token.ExpiresOn)
if cacheExpiry > 0 {
azureTokenCache.Set(key, token, cacheExpiry)
}
return token.AccessToken, nil
}

func (creds AzureWorkloadIdentityCreds) GetAzureDevOpsAccessToken() (string, error) {
Expand Down
53 changes: 50 additions & 3 deletions util/git/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"regexp"
"strings"
"testing"
"time"

"github.com/google/uuid"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -412,9 +414,10 @@ func TestGoogleCloudCreds_Environ_cleanup(t *testing.T) {
}

func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) {
resetAzureTokenCache()
store := &memoryCredsStore{creds: make(map[string]cred)}
workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
_, _, err := creds.Environ()
require.NoError(t, err)
Expand All @@ -427,9 +430,10 @@ func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) {
}

func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) {
resetAzureTokenCache()
store := &memoryCredsStore{creds: make(map[string]cred)}
workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
closer, _, err := creds.Environ()
require.NoError(t, err)
Expand All @@ -439,9 +443,10 @@ func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) {
}

func TestAzureWorkloadIdentityCreds_GetUserInfo(t *testing.T) {
resetAzureTokenCache()
store := &memoryCredsStore{creds: make(map[string]cred)}
workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}

user, email, err := creds.GetUserInfo(t.Context())
Expand All @@ -456,3 +461,45 @@ func TestGetHelmCredsShouldReturnHelmCredsIfAzureWorkloadIdentityNotSpecified(t
_, ok := creds.(AzureWorkloadIdentityCreds)
require.Truef(t, ok, "expected HelmCreds but got %T", creds)
}

func TestAzureWorkloadIdentityCreds_FetchNewTokenIfExistingIsExpired(t *testing.T) {
resetAzureTokenCache()
store := &memoryCredsStore{creds: make(map[string]cred)}
workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).
Return(&workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(time.Minute)}, nil).Once()
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).
Return(&workloadidentity.Token{AccessToken: "secondToken"}, nil).Once()
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
token, err := creds.GetAzureDevOpsAccessToken()
require.NoError(t, err)

assert.Equal(t, "firstToken", token)
time.Sleep(5 * time.Second)
token, err = creds.GetAzureDevOpsAccessToken()
require.NoError(t, err)
assert.Equal(t, "secondToken", token)
}

func TestAzureWorkloadIdentityCreds_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
resetAzureTokenCache()
store := &memoryCredsStore{creds: make(map[string]cred)}
workloadIdentityMock := new(mocks.TokenProvider)
firstToken := &workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(6 * time.Minute)}
secondToken := &workloadidentity.Token{AccessToken: "secondToken"}
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(firstToken, nil).Once()
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(secondToken, nil).Once()
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
token, err := creds.GetAzureDevOpsAccessToken()
require.NoError(t, err)

assert.Equal(t, "firstToken", token)
time.Sleep(5 * time.Second)
token, err = creds.GetAzureDevOpsAccessToken()
require.NoError(t, err)
assert.Equal(t, "firstToken", token)
}

func resetAzureTokenCache() {
azureTokenCache = gocache.New(gocache.NoExpiration, 0)
}
3 changes: 2 additions & 1 deletion util/helm/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"gopkg.in/yaml.v2"

"github.com/argoproj/argo-cd/v3/util/io"
"github.com/argoproj/argo-cd/v3/util/workloadidentity"
"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
)

Expand Down Expand Up @@ -300,7 +301,7 @@ func TestGetTagsFromURLPrivateRepoWithAzureWorkloadIdentityAuthentication(t *tes
}

workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)

mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("called %s", r.URL.Path)
Expand Down
30 changes: 27 additions & 3 deletions util/helm/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
gocache "github.com/patrickmn/go-cache"
log "github.com/sirupsen/logrus"

argoutils "github.com/argoproj/argo-cd/v3/util"
"github.com/argoproj/argo-cd/v3/util/env"
Expand Down Expand Up @@ -146,11 +148,33 @@ func (creds AzureWorkloadIdentityCreds) GetAccessToken() (string, error) {
return "", fmt.Errorf("failed to get Azure access token after challenge: %w", err)
}

// Access token has a lifetime of 3 hours
storeAzureToken(key, token, 2*time.Hour)
tokenExpiry, err := getJWTExpiry(token)
if err != nil {
log.Warnf("failed to get token expiry from JWT: %v, using current time as fallback", err)
tokenExpiry = time.Now()
}

cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry)
if cacheExpiry > 0 {
storeAzureToken(key, token, cacheExpiry)
}
return token, nil
}

func getJWTExpiry(token string) (time.Time, error) {
parser := jwt.NewParser()
claims := jwt.MapClaims{}
_, _, err := parser.ParseUnverified(token, claims)
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse JWT: %w", err)
}
exp, err := claims.GetExpirationTime()
if err != nil {
return time.Time{}, fmt.Errorf("'exp' claim not found or invalid in token: %w", err)
}
return time.UnixMilli(exp.UnixMilli()), nil
}

func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams map[string]string) (string, error) {
realm := tokenParams["realm"]
service := tokenParams["service"]
Expand All @@ -177,7 +201,7 @@ func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams
formValues := url.Values{}
formValues.Add("grant_type", "access_token")
formValues.Add("service", service)
formValues.Add("access_token", armAccessToken)
formValues.Add("access_token", armAccessToken.AccessToken)

resp, err := client.PostForm(refreshTokenURL, formValues)
if err != nil {
Expand Down
132 changes: 128 additions & 4 deletions util/helm/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -79,7 +81,7 @@ func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) {
defer mockServer.Close()

workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)

// Retrieve the token from the cache
Expand Down Expand Up @@ -191,7 +193,7 @@ func TestGetAccessTokenAfterChallenge_Success(t *testing.T) {
defer mockServer.Close()

workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)

tokenParams := map[string]string{
Expand All @@ -216,7 +218,7 @@ func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) {

// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)

tokenParams := map[string]string{
Expand All @@ -241,7 +243,7 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {

// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)

tokenParams := map[string]string{
Expand All @@ -253,3 +255,125 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
require.ErrorContains(t, err, "failed to unmarshal response body")
assert.Empty(t, refreshToken)
}

// Helper to generate a mock JWT token with a given expiry time
func generateMockJWT(expiry time.Time) (string, error) {
claims := jwt.MapClaims{
"exp": expiry.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// Use a dummy secret for signing
return token.SignedString([]byte("dummy-secret"))
}

func TestGetAccessToken_FetchNewTokenIfExistingIsExpired(t *testing.T) {
resetAzureTokenCache()
accessToken1, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))

mockServerURL := ""
mockedServerURL := func() string {
return mockServerURL
}

callCount := 0
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v2/":
assert.Equal(t, "/v2/", r.URL.Path)
w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:]))
w.WriteHeader(http.StatusUnauthorized)
case "/oauth2/exchange":
assert.Equal(t, "/oauth2/exchange", r.URL.Path)
var response string
switch callCount {
case 0:
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1)
case 1:
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2)
default:
response = `{"refresh_token": "defaultToken"}`
}
callCount++
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(response))
require.NoError(t, err)
default:
http.NotFound(w, r)
}
}))
defer mockServer.Close()
mockServerURL = mockServer.URL

workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)

refreshToken, err := creds.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, accessToken1, refreshToken)

time.Sleep(5 * time.Second) // Wait for the token to expire

refreshToken, err = creds.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, accessToken2, refreshToken)
}

func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
resetAzureTokenCache()
accessToken1, _ := generateMockJWT(time.Now().Add(6 * time.Minute))
accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))

mockServerURL := ""
mockedServerURL := func() string {
return mockServerURL
}

callCount := 0
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v2/":
assert.Equal(t, "/v2/", r.URL.Path)
w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:]))
w.WriteHeader(http.StatusUnauthorized)
case "/oauth2/exchange":
assert.Equal(t, "/oauth2/exchange", r.URL.Path)
var response string
switch callCount {
case 0:
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1)
case 1:
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2)
default:
response = `{"refresh_token": "defaultToken"}`
}
callCount++
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(response))
require.NoError(t, err)
default:
http.NotFound(w, r)
}
}))
defer mockServer.Close()
mockServerURL = mockServer.URL

workloadIdentityMock := new(mocks.TokenProvider)
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)

refreshToken, err := creds.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, accessToken1, refreshToken)

time.Sleep(5 * time.Second) // Wait for the token to expire

refreshToken, err = creds.GetAccessToken()
require.NoError(t, err)
assert.Equal(t, accessToken1, refreshToken)
}

func resetAzureTokenCache() {
azureTokenCache = gocache.New(gocache.NoExpiration, 0)
}
Loading
Loading