77	"testing" 
88	"time" 
99
10+ 	"github.com/golang-jwt/jwt/v5" 
11+ 	gocache "github.com/patrickmn/go-cache" 
1012	"github.com/stretchr/testify/assert" 
1113	"github.com/stretchr/testify/require" 
1214
@@ -79,7 +81,7 @@ func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) {
7981	defer  mockServer .Close ()
8082
8183	workloadIdentityMock  :=  new (mocks.TokenProvider )
82- 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
84+ 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken :  "accessToken" } , nil )
8385	creds  :=  NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
8486
8587	// Retrieve the token from the cache 
@@ -191,7 +193,7 @@ func TestGetAccessTokenAfterChallenge_Success(t *testing.T) {
191193	defer  mockServer .Close ()
192194
193195	workloadIdentityMock  :=  new (mocks.TokenProvider )
194- 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
196+ 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken :  "accessToken" } , nil )
195197	creds  :=  NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
196198
197199	tokenParams  :=  map [string ]string {
@@ -216,7 +218,7 @@ func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) {
216218
217219	// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper 
218220	workloadIdentityMock  :=  new (mocks.TokenProvider )
219- 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
221+ 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken :  "accessToken" } , nil )
220222	creds  :=  NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
221223
222224	tokenParams  :=  map [string ]string {
@@ -241,7 +243,7 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
241243
242244	// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper 
243245	workloadIdentityMock  :=  new (mocks.TokenProvider )
244- 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return ("accessToken" , nil )
246+ 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity. Token { AccessToken :  "accessToken" } , nil )
245247	creds  :=  NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
246248
247249	tokenParams  :=  map [string ]string {
@@ -253,3 +255,125 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
253255	require .ErrorContains (t , err , "failed to unmarshal response body" )
254256	assert .Empty (t , refreshToken )
255257}
258+ 
259+ // Helper to generate a mock JWT token with a given expiry time 
260+ func  generateMockJWT (expiry  time.Time ) (string , error ) {
261+ 	claims  :=  jwt.MapClaims {
262+ 		"exp" : expiry .Unix (),
263+ 	}
264+ 	token  :=  jwt .NewWithClaims (jwt .SigningMethodHS256 , claims )
265+ 	// Use a dummy secret for signing 
266+ 	return  token .SignedString ([]byte ("dummy-secret" ))
267+ }
268+ 
269+ func  TestGetAccessToken_FetchNewTokenIfExistingIsExpired (t  * testing.T ) {
270+ 	resetAzureTokenCache ()
271+ 	accessToken1 , _  :=  generateMockJWT (time .Now ().Add (1  *  time .Minute ))
272+ 	accessToken2 , _  :=  generateMockJWT (time .Now ().Add (1  *  time .Minute ))
273+ 
274+ 	mockServerURL  :=  "" 
275+ 	mockedServerURL  :=  func () string  {
276+ 		return  mockServerURL 
277+ 	}
278+ 
279+ 	callCount  :=  0 
280+ 	mockServer  :=  httptest .NewTLSServer (http .HandlerFunc (func (w  http.ResponseWriter , r  * http.Request ) {
281+ 		switch  r .URL .Path  {
282+ 		case  "/v2/" :
283+ 			assert .Equal (t , "/v2/" , r .URL .Path )
284+ 			w .Header ().Set ("Www-Authenticate" , fmt .Sprintf (`Bearer realm="%s",service="%s"` , mockedServerURL (), mockedServerURL ()[8 :]))
285+ 			w .WriteHeader (http .StatusUnauthorized )
286+ 		case  "/oauth2/exchange" :
287+ 			assert .Equal (t , "/oauth2/exchange" , r .URL .Path )
288+ 			var  response  string 
289+ 			switch  callCount  {
290+ 			case  0 :
291+ 				response  =  fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken1 )
292+ 			case  1 :
293+ 				response  =  fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken2 )
294+ 			default :
295+ 				response  =  `{"refresh_token": "defaultToken"}` 
296+ 			}
297+ 			callCount ++ 
298+ 			w .WriteHeader (http .StatusOK )
299+ 			_ , err  :=  w .Write ([]byte (response ))
300+ 			require .NoError (t , err )
301+ 		default :
302+ 			http .NotFound (w , r )
303+ 		}
304+ 	}))
305+ 	defer  mockServer .Close ()
306+ 	mockServerURL  =  mockServer .URL 
307+ 
308+ 	workloadIdentityMock  :=  new (mocks.TokenProvider )
309+ 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity.Token {AccessToken : "accessToken" }, nil )
310+ 	creds  :=  NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
311+ 
312+ 	refreshToken , err  :=  creds .GetAccessToken ()
313+ 	require .NoError (t , err )
314+ 	assert .Equal (t , accessToken1 , refreshToken )
315+ 
316+ 	time .Sleep (5  *  time .Second ) // Wait for the token to expire 
317+ 
318+ 	refreshToken , err  =  creds .GetAccessToken ()
319+ 	require .NoError (t , err )
320+ 	assert .Equal (t , accessToken2 , refreshToken )
321+ }
322+ 
323+ func  TestGetAccessToken_ReuseTokenIfExistingIsNotExpired (t  * testing.T ) {
324+ 	resetAzureTokenCache ()
325+ 	accessToken1 , _  :=  generateMockJWT (time .Now ().Add (6  *  time .Minute ))
326+ 	accessToken2 , _  :=  generateMockJWT (time .Now ().Add (1  *  time .Minute ))
327+ 
328+ 	mockServerURL  :=  "" 
329+ 	mockedServerURL  :=  func () string  {
330+ 		return  mockServerURL 
331+ 	}
332+ 
333+ 	callCount  :=  0 
334+ 	mockServer  :=  httptest .NewTLSServer (http .HandlerFunc (func (w  http.ResponseWriter , r  * http.Request ) {
335+ 		switch  r .URL .Path  {
336+ 		case  "/v2/" :
337+ 			assert .Equal (t , "/v2/" , r .URL .Path )
338+ 			w .Header ().Set ("Www-Authenticate" , fmt .Sprintf (`Bearer realm="%s",service="%s"` , mockedServerURL (), mockedServerURL ()[8 :]))
339+ 			w .WriteHeader (http .StatusUnauthorized )
340+ 		case  "/oauth2/exchange" :
341+ 			assert .Equal (t , "/oauth2/exchange" , r .URL .Path )
342+ 			var  response  string 
343+ 			switch  callCount  {
344+ 			case  0 :
345+ 				response  =  fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken1 )
346+ 			case  1 :
347+ 				response  =  fmt .Sprintf (`{"refresh_token": "%s"}` , accessToken2 )
348+ 			default :
349+ 				response  =  `{"refresh_token": "defaultToken"}` 
350+ 			}
351+ 			callCount ++ 
352+ 			w .WriteHeader (http .StatusOK )
353+ 			_ , err  :=  w .Write ([]byte (response ))
354+ 			require .NoError (t , err )
355+ 		default :
356+ 			http .NotFound (w , r )
357+ 		}
358+ 	}))
359+ 	defer  mockServer .Close ()
360+ 	mockServerURL  =  mockServer .URL 
361+ 
362+ 	workloadIdentityMock  :=  new (mocks.TokenProvider )
363+ 	workloadIdentityMock .On ("GetToken" , "https://management.core.windows.net/.default" ).Return (& workloadidentity.Token {AccessToken : "accessToken" }, nil )
364+ 	creds  :=  NewAzureWorkloadIdentityCreds (mockServer .URL [8 :], "" , nil , nil , true , workloadIdentityMock )
365+ 
366+ 	refreshToken , err  :=  creds .GetAccessToken ()
367+ 	require .NoError (t , err )
368+ 	assert .Equal (t , accessToken1 , refreshToken )
369+ 
370+ 	time .Sleep (5  *  time .Second ) // Wait for the token to expire 
371+ 
372+ 	refreshToken , err  =  creds .GetAccessToken ()
373+ 	require .NoError (t , err )
374+ 	assert .Equal (t , accessToken1 , refreshToken )
375+ }
376+ 
377+ func  resetAzureTokenCache () {
378+ 	azureTokenCache  =  gocache .New (gocache .NoExpiration , 0 )
379+ }
0 commit comments