Skip to content

Commit 78ed5f3

Browse files
Retry on transient failures during AWS IAM auth login attempts (#8727)
* use retryer for failed aws auth attempts * fixes from testing
1 parent 2de996b commit 78ed5f3

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

builtin/credential/aws/path_login.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ import (
1616
"time"
1717

1818
"github.com/aws/aws-sdk-go/aws"
19+
awsClient "github.com/aws/aws-sdk-go/aws/client"
1920
"github.com/aws/aws-sdk-go/service/ec2"
2021
"github.com/aws/aws-sdk-go/service/iam"
2122
"github.com/fullsailor/pkcs7"
2223
"github.com/hashicorp/errwrap"
2324
cleanhttp "github.com/hashicorp/go-cleanhttp"
25+
"github.com/hashicorp/go-retryablehttp"
2426
uuid "github.com/hashicorp/go-uuid"
2527
"github.com/hashicorp/vault/sdk/framework"
2628
"github.com/hashicorp/vault/sdk/helper/awsutil"
@@ -35,6 +37,10 @@ const (
3537
iamAuthType = "iam"
3638
ec2AuthType = "ec2"
3739
ec2EntityType = "ec2_instance"
40+
41+
// Retry configuration
42+
retryWaitMin = 500 * time.Millisecond
43+
retryWaitMax = 30 * time.Second
3844
)
3945

4046
func (b *backend) pathLogin() *framework.Path {
@@ -1198,6 +1204,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
11981204

11991205
endpoint := "https://sts.amazonaws.com"
12001206

1207+
maxRetries := awsClient.DefaultRetryerMaxNumRetries
12011208
if config != nil {
12021209
if config.IAMServerIdHeaderValue != "" {
12031210
err = validateVaultHeaderValue(headers, parsedUrl, config.IAMServerIdHeaderValue)
@@ -1208,9 +1215,12 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
12081215
if config.STSEndpoint != "" {
12091216
endpoint = config.STSEndpoint
12101217
}
1218+
if config.MaxRetries >= 0 {
1219+
maxRetries = config.MaxRetries
1220+
}
12111221
}
12121222

1213-
callerID, err := submitCallerIdentityRequest(method, endpoint, parsedUrl, body, headers)
1223+
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
12141224
if err != nil {
12151225
return logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
12161226
}
@@ -1555,18 +1565,31 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
15551565
return result, err
15561566
}
15571567

1558-
func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
1568+
func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
15591569
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
15601570
// The protection against this is that this method will only call the endpoint specified in the
15611571
// client config (defaulting to sts.amazonaws.com), so it would require a Vault admin to override
15621572
// the endpoint to talk to alternate web addresses
15631573
request := buildHttpRequest(method, endpoint, parsedUrl, body, headers)
1574+
retryableReq, err := retryablehttp.FromRequest(request)
1575+
if err != nil {
1576+
return nil, err
1577+
}
1578+
retryableReq = retryableReq.WithContext(ctx)
15641579
client := cleanhttp.DefaultClient()
15651580
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
15661581
return http.ErrUseLastResponse
15671582
}
1583+
retryingClient := &retryablehttp.Client{
1584+
HTTPClient: client,
1585+
RetryWaitMin: retryWaitMin,
1586+
RetryWaitMax: retryWaitMax,
1587+
RetryMax: maxRetries,
1588+
CheckRetry: retryablehttp.DefaultRetryPolicy,
1589+
Backoff: retryablehttp.DefaultBackoff,
1590+
}
15681591

1569-
response, err := client.Do(request)
1592+
response, err := retryingClient.Do(retryableReq)
15701593
if err != nil {
15711594
return nil, errwrap.Wrapf("error making request: {{err}}", err)
15721595
}

0 commit comments

Comments
 (0)