77 "testing"
88 "time"
99
10+ "github.com/golang-jwt/jwt/v5"
11+ gocache "github.com/patrickmn/go-cache"
1012 "github.com/stretchr/testify/assert"
1113 "github.com/stretchr/testify/require"
1214
@@ -79,7 +81,7 @@ func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) {
7981 defer mockServer .Close ()
8082
8183 workloadIdentityMock := new (mocks.TokenProvider )
82- workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
84+ workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken : "accessToken" } , nil )
8385 creds := NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
8486
8587 // Retrieve the token from the cache
@@ -191,7 +193,7 @@ func TestGetAccessTokenAfterChallenge_Success(t *testing.T) {
191193 defer mockServer .Close ()
192194
193195 workloadIdentityMock := new (mocks.TokenProvider )
194- workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
196+ workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken : "accessToken" } , nil )
195197 creds := NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
196198
197199 tokenParams := map [string ]string {
@@ -216,7 +218,7 @@ func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) {
216218
217219 // Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
218220 workloadIdentityMock := new (mocks.TokenProvider )
219- workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
221+ workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken : "accessToken" } , nil )
220222 creds := NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
221223
222224 tokenParams := map [string ]string {
@@ -241,7 +243,7 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
241243
242244 // Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
243245 workloadIdentityMock := new (mocks.TokenProvider )
244- workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
246+ workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken : "accessToken" } , nil )
245247 creds := NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
246248
247249 tokenParams := map [string ]string {
@@ -253,3 +255,125 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
253255 require .ErrorContains (t , err , "failed to unmarshal response body" )
254256 assert .Empty (t , refreshToken )
255257}
258+
259+ // Helper to generate a mock JWT token with a given expiry time
260+ func generateMockJWT (expiry time.Time ) (string , error ) {
261+ claims := jwt.MapClaims {
262+ "exp" : expiry .Unix (),
263+ }
264+ token := jwt .NewWithClaims (jwt .SigningMethodHS256 , claims )
265+ // Use a dummy secret for signing
266+ return token .SignedString ([]byte ("dummy-secret" ))
267+ }
268+
269+ func TestGetAccessToken_FetchNewTokenIfExistingIsExpired (t * testing.T ) {
270+ resetAzureTokenCache ()
271+ accessToken1 , _ := generateMockJWT (time .Now ().Add (1 * time .Minute ))
272+ accessToken2 , _ := generateMockJWT (time .Now ().Add (1 * time .Minute ))
273+
274+ mockServerURL := ""
275+ mockedServerURL := func () string {
276+ return mockServerURL
277+ }
278+
279+ callCount := 0
280+ mockServer := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
281+ switch r .URL .Path {
282+ case "/v2/" :
283+ assert .Equal (t , "/v2/" , r .URL .Path )
284+ w .Header ().Set ("Www-Authenticate" , fmt .Sprintf (`Bearer realm="%s",service="%s"` , mockedServerURL (), mockedServerURL ()[8 :]))
285+ w .WriteHeader (http .StatusUnauthorized )
286+ case "/oauth2/exchange" :
287+ assert .Equal (t , "/oauth2/exchange" , r .URL .Path )
288+ var response string
289+ switch callCount {
290+ case 0 :
291+ response = fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken1 )
292+ case 1 :
293+ response = fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken2 )
294+ default :
295+ response = `{"refresh_token": "defaultToken"}`
296+ }
297+ callCount ++
298+ w .WriteHeader (http .StatusOK )
299+ _ , err := w .Write ([]byte (response ))
300+ require .NoError (t , err )
301+ default :
302+ http .NotFound (w , r )
303+ }
304+ }))
305+ defer mockServer .Close ()
306+ mockServerURL = mockServer .URL
307+
308+ workloadIdentityMock := new (mocks.TokenProvider )
309+ workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity.Token {AccessToken : "accessToken" }, nil )
310+ creds := NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
311+
312+ refreshToken , err := creds .GetAccessToken ()
313+ require .NoError (t , err )
314+ assert .Equal (t , accessToken1 , refreshToken )
315+
316+ time .Sleep (5 * time .Second ) // Wait for the token to expire
317+
318+ refreshToken , err = creds .GetAccessToken ()
319+ require .NoError (t , err )
320+ assert .Equal (t , accessToken2 , refreshToken )
321+ }
322+
323+ func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired (t * testing.T ) {
324+ resetAzureTokenCache ()
325+ accessToken1 , _ := generateMockJWT (time .Now ().Add (6 * time .Minute ))
326+ accessToken2 , _ := generateMockJWT (time .Now ().Add (1 * time .Minute ))
327+
328+ mockServerURL := ""
329+ mockedServerURL := func () string {
330+ return mockServerURL
331+ }
332+
333+ callCount := 0
334+ mockServer := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
335+ switch r .URL .Path {
336+ case "/v2/" :
337+ assert .Equal (t , "/v2/" , r .URL .Path )
338+ w .Header ().Set ("Www-Authenticate" , fmt .Sprintf (`Bearer realm="%s",service="%s"` , mockedServerURL (), mockedServerURL ()[8 :]))
339+ w .WriteHeader (http .StatusUnauthorized )
340+ case "/oauth2/exchange" :
341+ assert .Equal (t , "/oauth2/exchange" , r .URL .Path )
342+ var response string
343+ switch callCount {
344+ case 0 :
345+ response = fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken1 )
346+ case 1 :
347+ response = fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken2 )
348+ default :
349+ response = `{"refresh_token": "defaultToken"}`
350+ }
351+ callCount ++
352+ w .WriteHeader (http .StatusOK )
353+ _ , err := w .Write ([]byte (response ))
354+ require .NoError (t , err )
355+ default :
356+ http .NotFound (w , r )
357+ }
358+ }))
359+ defer mockServer .Close ()
360+ mockServerURL = mockServer .URL
361+
362+ workloadIdentityMock := new (mocks.TokenProvider )
363+ workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity.Token {AccessToken : "accessToken" }, nil )
364+ creds := NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
365+
366+ refreshToken , err := creds .GetAccessToken ()
367+ require .NoError (t , err )
368+ assert .Equal (t , accessToken1 , refreshToken )
369+
370+ time .Sleep (5 * time .Second ) // Wait for the token to expire
371+
372+ refreshToken , err = creds .GetAccessToken ()
373+ require .NoError (t , err )
374+ assert .Equal (t , accessToken1 , refreshToken )
375+ }
376+
377+ func resetAzureTokenCache () {
378+ azureTokenCache = gocache .New (gocache .NoExpiration , 0 )
379+ }
0 commit comments