@@ -16,11 +16,13 @@ import (
1616 "time"
1717
1818 "github.com/aws/aws-sdk-go/aws"
19+ awsClient "github.com/aws/aws-sdk-go/aws/client"
1920 "github.com/aws/aws-sdk-go/service/ec2"
2021 "github.com/aws/aws-sdk-go/service/iam"
2122 "github.com/fullsailor/pkcs7"
2223 "github.com/hashicorp/errwrap"
2324 cleanhttp "github.com/hashicorp/go-cleanhttp"
25+ "github.com/hashicorp/go-retryablehttp"
2426 uuid "github.com/hashicorp/go-uuid"
2527 "github.com/hashicorp/vault/sdk/framework"
2628 "github.com/hashicorp/vault/sdk/helper/awsutil"
@@ -35,6 +37,10 @@ const (
3537 iamAuthType = "iam"
3638 ec2AuthType = "ec2"
3739 ec2EntityType = "ec2_instance"
40+
41+ // Retry configuration
42+ retryWaitMin = 500 * time .Millisecond
43+ retryWaitMax = 30 * time .Second
3844)
3945
4046func (b * backend ) pathLogin () * framework.Path {
@@ -1198,6 +1204,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
11981204
11991205 endpoint := "https://sts.amazonaws.com"
12001206
1207+ maxRetries := awsClient .DefaultRetryerMaxNumRetries
12011208 if config != nil {
12021209 if config .IAMServerIdHeaderValue != "" {
12031210 err = validateVaultHeaderValue (headers , parsedUrl , config .IAMServerIdHeaderValue )
@@ -1208,9 +1215,12 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
12081215 if config .STSEndpoint != "" {
12091216 endpoint = config .STSEndpoint
12101217 }
1218+ if config .MaxRetries >= 0 {
1219+ maxRetries = config .MaxRetries
1220+ }
12111221 }
12121222
1213- callerID , err := submitCallerIdentityRequest (method , endpoint , parsedUrl , body , headers )
1223+ callerID , err := submitCallerIdentityRequest (ctx , maxRetries , method , endpoint , parsedUrl , body , headers )
12141224 if err != nil {
12151225 return logical .ErrorResponse (fmt .Sprintf ("error making upstream request: %v" , err )), nil
12161226 }
@@ -1555,18 +1565,31 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
15551565 return result , err
15561566}
15571567
1558- func submitCallerIdentityRequest (method , endpoint string , parsedUrl * url.URL , body string , headers http.Header ) (* GetCallerIdentityResult , error ) {
1568+ func submitCallerIdentityRequest (ctx context. Context , maxRetries int , method , endpoint string , parsedUrl * url.URL , body string , headers http.Header ) (* GetCallerIdentityResult , error ) {
15591569 // NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
15601570 // The protection against this is that this method will only call the endpoint specified in the
15611571 // client config (defaulting to sts.amazonaws.com), so it would require a Vault admin to override
15621572 // the endpoint to talk to alternate web addresses
15631573 request := buildHttpRequest (method , endpoint , parsedUrl , body , headers )
1574+ retryableReq , err := retryablehttp .FromRequest (request )
1575+ if err != nil {
1576+ return nil , err
1577+ }
1578+ retryableReq = retryableReq .WithContext (ctx )
15641579 client := cleanhttp .DefaultClient ()
15651580 client .CheckRedirect = func (req * http.Request , via []* http.Request ) error {
15661581 return http .ErrUseLastResponse
15671582 }
1583+ retryingClient := & retryablehttp.Client {
1584+ HTTPClient : client ,
1585+ RetryWaitMin : retryWaitMin ,
1586+ RetryWaitMax : retryWaitMax ,
1587+ RetryMax : maxRetries ,
1588+ CheckRetry : retryablehttp .DefaultRetryPolicy ,
1589+ Backoff : retryablehttp .DefaultBackoff ,
1590+ }
15681591
1569- response , err := client .Do (request )
1592+ response , err := retryingClient .Do (retryableReq )
15701593 if err != nil {
15711594 return nil , errwrap .Wrapf ("error making request: {{err}}" , err )
15721595 }
0 commit comments