diff --git a/aws/credentials/processcreds/provider.go b/aws/credentials/processcreds/provider.go new file mode 100644 index 00000000000..f56e70c7392 --- /dev/null +++ b/aws/credentials/processcreds/provider.go @@ -0,0 +1,425 @@ +/* +Package processcreds is a credential Provider to retrieve `credential_process` +credentials. + +WARNING: The following describes a method of sourcing credentials from an external +process. This can potentially be dangerous, so proceed with caution. Other +credential providers should be preferred if at all possible. If using this +option, you should make sure that the config file is as locked down as possible +using security best practices for your operating system. + +You can use credentials from a `credential_process` in a variety of ways. + +One way is to setup your shared config file, located in the default +location, with the `credential_process` key and the command you want to be +called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable +(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file. + + [default] + credential_process = /command/to/call + +Creating a new session will use the credential process to retrieve credentials. +NOTE: If there are credentials in the profile you are using, the credential +process will not be used. + + // Initialize a session to load credentials. + sess, _ := session.NewSession(&aws.Config{ + Region: aws.String("us-east-1")}, + ) + + // Create S3 service client to use the credentials. + svc := s3.New(sess) + +Another way to use the `credential_process` method is by using +`credentials.NewCredentials()` and providing a command to be executed to +retrieve credentials: + + // Create credentials using the ProcessProvider. + creds := processcreds.NewCredentials("/path/to/command") + + // Create service client value configured for credentials. + svc := s3.New(sess, &aws.Config{Credentials: creds}) + +You can set a non-default timeout for the `credential_process` with another +constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To +set a one minute timeout: + + // Create credentials using the ProcessProvider. + creds := processcreds.NewCredentialsTimeout( + "/path/to/command", + time.Duration(500) * time.Millisecond) + +If you need more control, you can set any configurable options in the +credentials using one or more option functions. For example, you can set a two +minute timeout, a credential duration of 60 minutes, and a maximum stdout +buffer size of 2k. + + creds := processcreds.NewCredentials( + "/path/to/command", + func(opt *ProcessProvider) { + opt.Timeout = time.Duration(2) * time.Minute + opt.Duration = time.Duration(60) * time.Minute + opt.MaxBufSize = 2048 + }) + +You can also use your own `exec.Cmd`: + + // Create an exec.Cmd + myCommand := exec.Command("/path/to/command") + + // Create credentials using your exec.Cmd and custom timeout + creds := processcreds.NewCredentialsCommand( + myCommand, + func(opt *processcreds.ProcessProvider) { + opt.Timeout = time.Duration(1) * time.Second + }) +*/ +package processcreds + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +const ( + // ProviderName is the name this credentials provider will label any + // returned credentials Value with. + ProviderName = `ProcessProvider` + + // ErrCodeProcessProviderParse error parsing process output + ErrCodeProcessProviderParse = "ProcessProviderParseError" + + // ErrCodeProcessProviderVersion version error in output + ErrCodeProcessProviderVersion = "ProcessProviderVersionError" + + // ErrCodeProcessProviderRequired required attribute missing in output + ErrCodeProcessProviderRequired = "ProcessProviderRequiredError" + + // ErrCodeProcessProviderExecution execution of command failed + ErrCodeProcessProviderExecution = "ProcessProviderExecutionError" + + // errMsgProcessProviderTimeout process took longer than allowed + errMsgProcessProviderTimeout = "credential process timed out" + + // errMsgProcessProviderProcess process error + errMsgProcessProviderProcess = "error in credential_process" + + // errMsgProcessProviderParse problem parsing output + errMsgProcessProviderParse = "parse failed of credential_process output" + + // errMsgProcessProviderVersion version error in output + errMsgProcessProviderVersion = "wrong version in process output (not 1)" + + // errMsgProcessProviderMissKey missing access key id in output + errMsgProcessProviderMissKey = "missing AccessKeyId in process output" + + // errMsgProcessProviderMissSecret missing secret acess key in output + errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output" + + // errMsgProcessProviderPrepareCmd prepare of command failed + errMsgProcessProviderPrepareCmd = "failed to prepare command" + + // errMsgProcessProviderEmptyCmd command must not be empty + errMsgProcessProviderEmptyCmd = "command must not be empty" + + // errMsgProcessProviderPipe failed to initialize pipe + errMsgProcessProviderPipe = "failed to initialize pipe" + + // DefaultDuration is the default amount of time in minutes that the + // credentials will be valid for. + DefaultDuration = time.Duration(15) * time.Minute + + // DefaultBufSize limits buffer size from growing to an enormous + // amount due to a faulty process. + DefaultBufSize = 512 + + // DefaultTimeout default limit on time a process can run. + DefaultTimeout = time.Duration(1) * time.Minute +) + +// ProcessProvider satisfies the credentials.Provider interface, and is a +// client to retrieve credentials from a process. +type ProcessProvider struct { + staticCreds bool + credentials.Expiry + originalCommand []string + + // Expiry duration of the credentials. Defaults to 15 minutes if not set. + Duration time.Duration + + // ExpiryWindow will allow the credentials to trigger refreshing prior to + // the credentials actually expiring. This is beneficial so race conditions + // with expiring credentials do not cause request to fail unexpectedly + // due to ExpiredTokenException exceptions. + // + // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true + // 10 seconds before the credentials are actually expired. + // + // If ExpiryWindow is 0 or less it will be ignored. + ExpiryWindow time.Duration + + // A string representing an os command that should return a JSON with + // credential information. + command *exec.Cmd + + // MaxBufSize limits memory usage from growing to an enormous + // amount due to a faulty process. + MaxBufSize int + + // Timeout limits the time a process can run. + Timeout time.Duration +} + +// NewCredentials returns a pointer to a new Credentials object wrapping the +// ProcessProvider. The credentials will expire every 15 minutes by default. +func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials { + p := &ProcessProvider{ + command: exec.Command(command), + Duration: DefaultDuration, + Timeout: DefaultTimeout, + MaxBufSize: DefaultBufSize, + } + + for _, option := range options { + option(p) + } + + return credentials.NewCredentials(p) +} + +// NewCredentialsTimeout returns a pointer to a new Credentials object with +// the specified command and timeout, and default duration and max buffer size. +func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials { + p := NewCredentials(command, func(opt *ProcessProvider) { + opt.Timeout = timeout + }) + + return p +} + +// NewCredentialsCommand returns a pointer to a new Credentials object with +// the specified command, and default timeout, duration and max buffer size. +func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials { + p := &ProcessProvider{ + command: command, + Duration: DefaultDuration, + Timeout: DefaultTimeout, + MaxBufSize: DefaultBufSize, + } + + for _, option := range options { + option(p) + } + + return credentials.NewCredentials(p) +} + +type credentialProcessResponse struct { + Version int + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string + SessionToken string + Expiration *time.Time +} + +// Retrieve executes the 'credential_process' and returns the credentials. +func (p *ProcessProvider) Retrieve() (credentials.Value, error) { + out, err := p.executeCredentialProcess() + if err != nil { + return credentials.Value{ProviderName: ProviderName}, err + } + + // Serialize and validate response + resp := &credentialProcessResponse{} + if err = json.Unmarshal(out, resp); err != nil { + return credentials.Value{ProviderName: ProviderName}, awserr.New( + ErrCodeProcessProviderParse, + fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)), + err) + } + + if resp.Version != 1 { + return credentials.Value{ProviderName: ProviderName}, awserr.New( + ErrCodeProcessProviderVersion, + errMsgProcessProviderVersion, + nil) + } + + if len(resp.AccessKeyID) == 0 { + return credentials.Value{ProviderName: ProviderName}, awserr.New( + ErrCodeProcessProviderRequired, + errMsgProcessProviderMissKey, + nil) + } + + if len(resp.SecretAccessKey) == 0 { + return credentials.Value{ProviderName: ProviderName}, awserr.New( + ErrCodeProcessProviderRequired, + errMsgProcessProviderMissSecret, + nil) + } + + // Handle expiration + p.staticCreds = resp.Expiration == nil + if resp.Expiration != nil { + p.SetExpiration(*resp.Expiration, p.ExpiryWindow) + } + + return credentials.Value{ + ProviderName: ProviderName, + AccessKeyID: resp.AccessKeyID, + SecretAccessKey: resp.SecretAccessKey, + SessionToken: resp.SessionToken, + }, nil +} + +// IsExpired returns true if the credentials retrieved are expired, or not yet +// retrieved. +func (p *ProcessProvider) IsExpired() bool { + if p.staticCreds { + return false + } + return p.Expiry.IsExpired() +} + +// prepareCommand prepares the command to be executed. +func (p *ProcessProvider) prepareCommand() error { + + var cmdArgs []string + if runtime.GOOS == "windows" { + cmdArgs = []string{"cmd.exe", "/C"} + } else { + cmdArgs = []string{"sh", "-c"} + } + + if len(p.originalCommand) == 0 { + p.originalCommand = make([]string, len(p.command.Args)) + copy(p.originalCommand, p.command.Args) + + // check for empty command because it succeeds + if len(strings.TrimSpace(p.originalCommand[0])) < 1 { + return awserr.New( + ErrCodeProcessProviderExecution, + fmt.Sprintf( + "%s: %s", + errMsgProcessProviderPrepareCmd, + errMsgProcessProviderEmptyCmd), + nil) + } + } + + cmdArgs = append(cmdArgs, p.originalCommand...) + p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...) + p.command.Env = os.Environ() + + return nil +} + +// executeCredentialProcess starts the credential process on the OS and +// returns the results or an error. +func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) { + + if err := p.prepareCommand(); err != nil { + return nil, err + } + + // Setup the pipes + outReadPipe, outWritePipe, err := os.Pipe() + if err != nil { + return nil, awserr.New( + ErrCodeProcessProviderExecution, + errMsgProcessProviderPipe, + err) + } + + p.command.Stderr = os.Stderr // display stderr on console for MFA + p.command.Stdout = outWritePipe // get creds json on process's stdout + p.command.Stdin = os.Stdin // enable stdin for MFA + + output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize)) + + stdoutCh := make(chan error, 1) + go readInput( + io.LimitReader(outReadPipe, int64(p.MaxBufSize)), + output, + stdoutCh) + + execCh := make(chan error, 1) + go executeCommand(*p.command, execCh) + + finished := false + var errors []error + for !finished { + select { + case readError := <-stdoutCh: + errors = appendError(errors, readError) + finished = true + case execError := <-execCh: + err := outWritePipe.Close() + errors = appendError(errors, err) + errors = appendError(errors, execError) + if errors != nil { + return output.Bytes(), awserr.NewBatchError( + ErrCodeProcessProviderExecution, + errMsgProcessProviderProcess, + errors) + } + case <-time.After(p.Timeout): + finished = true + return output.Bytes(), awserr.NewBatchError( + ErrCodeProcessProviderExecution, + errMsgProcessProviderTimeout, + errors) // errors can be nil + } + } + + out := output.Bytes() + + if runtime.GOOS == "windows" { + // windows adds slashes to quotes + out = []byte(strings.Replace(string(out), `\"`, `"`, -1)) + } + + return out, nil +} + +// appendError conveniently checks for nil before appending slice +func appendError(errors []error, err error) []error { + if err != nil { + return append(errors, err) + } + return errors +} + +func executeCommand(cmd exec.Cmd, exec chan error) { + // Start the command + err := cmd.Start() + if err == nil { + err = cmd.Wait() + } + + exec <- err +} + +func readInput(r io.Reader, w io.Writer, read chan error) { + tee := io.TeeReader(r, w) + + _, err := ioutil.ReadAll(tee) + + if err == io.EOF { + err = nil + } + + read <- err // will only arrive here when write end of pipe is closed +} diff --git a/aws/credentials/processcreds/provider_test.go b/aws/credentials/processcreds/provider_test.go new file mode 100644 index 00000000000..fd3253d1cfe --- /dev/null +++ b/aws/credentials/processcreds/provider_test.go @@ -0,0 +1,561 @@ +package processcreds_test + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "os/exec" + "runtime" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials/processcreds" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/awstesting" +) + +func TestProcessProviderFromSessionCfg(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + if runtime.GOOS == "windows" { + os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini") + } else { + os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini") + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("region")}, + ) + + if err != nil { + t.Errorf("error getting session: %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("error getting credentials: %v", err) + } + + if e, a := "accessKey", creds.AccessKeyID; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := "secret", creds.SecretAccessKey; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := "tokenDefault", creds.SessionToken; e != a { + t.Errorf("expected %v, got %v", e, a) + } + +} + +func TestProcessProviderFromSessionWithProfileCfg(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_PROFILE", "non_expire") + if runtime.GOOS == "windows" { + os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini") + } else { + os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini") + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("region")}, + ) + + if err != nil { + t.Errorf("error getting session: %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("error getting credentials: %v", err) + } + + if e, a := "nonDefaultToken", creds.SessionToken; e != a { + t.Errorf("expected %v, got %v", e, a) + } + +} + +func TestProcessProviderNotFromCredProcCfg(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_PROFILE", "not_alone") + if runtime.GOOS == "windows" { + os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini") + } else { + os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini") + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("region")}, + ) + + if err != nil { + t.Errorf("error getting session: %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("error getting credentials: %v", err) + } + + if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a { + t.Errorf("expected %v, got %v", e, a) + } + +} + +func TestProcessProviderFromSessionCrd(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + if runtime.GOOS == "windows" { + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini") + } else { + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini") + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("region")}, + ) + + if err != nil { + t.Errorf("error getting session: %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("error getting credentials: %v", err) + } + + if e, a := "accessKey", creds.AccessKeyID; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := "secret", creds.SecretAccessKey; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := "tokenDefault", creds.SessionToken; e != a { + t.Errorf("expected %v, got %v", e, a) + } + +} + +func TestProcessProviderFromSessionWithProfileCrd(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_PROFILE", "non_expire") + if runtime.GOOS == "windows" { + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini") + } else { + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini") + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("region")}, + ) + + if err != nil { + t.Errorf("error getting session: %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("error getting credentials: %v", err) + } + + if e, a := "nonDefaultToken", creds.SessionToken; e != a { + t.Errorf("expected %v, got %v", e, a) + } + +} + +func TestProcessProviderNotFromCredProcCrd(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_PROFILE", "not_alone") + if runtime.GOOS == "windows" { + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini") + } else { + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini") + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("region")}, + ) + + if err != nil { + t.Errorf("error getting session: %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("error getting credentials: %v", err) + } + + if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a { + t.Errorf("expected %v, got %v", e, a) + } + + if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a { + t.Errorf("expected %v, got %v", e, a) + } + +} + +func TestProcessProviderBadCommand(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + creds := processcreds.NewCredentials("/bad/process") + _, err := creds.Get() + if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err) + } +} + +func TestProcessProviderMoreEmptyCommands(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + creds := processcreds.NewCredentials("") + _, err := creds.Get() + if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err) + } + +} + +func TestProcessProviderExpectErrors(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + creds := processcreds.NewCredentials( + fmt.Sprintf( + "%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "malformed.json"}, + string(os.PathSeparator)))) + _, err := creds.Get() + if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderParse { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderParse, err) + } + + creds = processcreds.NewCredentials( + fmt.Sprintf("%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "wrongversion.json"}, + string(os.PathSeparator)))) + _, err = creds.Get() + if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderVersion { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderVersion, err) + } + + creds = processcreds.NewCredentials( + fmt.Sprintf( + "%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "missingkey.json"}, + string(os.PathSeparator)))) + _, err = creds.Get() + if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err) + } + + creds = processcreds.NewCredentials( + fmt.Sprintf( + "%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "missingsecret.json"}, + string(os.PathSeparator)))) + _, err = creds.Get() + if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err) + } + +} + +func TestProcessProviderTimeout(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + command := "/bin/sleep 2" + if runtime.GOOS == "windows" { + // "timeout" command does not work due to pipe redirection + command = "ping -n 2 127.0.0.1>nul" + } + + creds := processcreds.NewCredentialsTimeout( + command, + time.Duration(1)*time.Second) + if _, err := creds.Get(); err == nil || err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution || err.(awserr.Error).Message() != "credential process timed out" { + t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err) + } + +} + +type credentialTest struct { + Version int + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string + Expiration string +} + +func TestProcessProviderStatic(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + // static + creds := processcreds.NewCredentials( + fmt.Sprintf( + "%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "static.json"}, + string(os.PathSeparator)))) + _, err := creds.Get() + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + if creds.IsExpired() { + t.Errorf("expected %v, got %v", "static credentials/not expired", "expired") + } + +} + +func TestProcessProviderNotExpired(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + // non-static, not expired + exp := &credentialTest{} + exp.Version = 1 + exp.AccessKeyID = "accesskey" + exp.SecretAccessKey = "secretkey" + exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339) + b, err := json.Marshal(exp) + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + + tmpFile := strings.Join( + []string{"testdata", "tmp_expiring.json"}, + string(os.PathSeparator)) + if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + defer func() { + if err = os.Remove(tmpFile); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + }() + creds := processcreds.NewCredentials( + fmt.Sprintf("%s %s", getOSCat(), tmpFile)) + _, err = creds.Get() + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + if creds.IsExpired() { + t.Errorf("expected %v, got %v", "not expired", "expired") + } +} + +func TestProcessProviderExpired(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + // non-static, expired + exp := &credentialTest{} + exp.Version = 1 + exp.AccessKeyID = "accesskey" + exp.SecretAccessKey = "secretkey" + exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339) + b, err := json.Marshal(exp) + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + + tmpFile := strings.Join( + []string{"testdata", "tmp_expired.json"}, + string(os.PathSeparator)) + if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + defer func() { + if err = os.Remove(tmpFile); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + }() + creds := processcreds.NewCredentials( + fmt.Sprintf("%s %s", getOSCat(), tmpFile)) + _, err = creds.Get() + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + if !creds.IsExpired() { + t.Errorf("expected %v, got %v", "expired", "not expired") + } +} + +func TestProcessProviderForceExpire(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + // non-static, not expired + + // setup test credentials file + exp := &credentialTest{} + exp.Version = 1 + exp.AccessKeyID = "accesskey" + exp.SecretAccessKey = "secretkey" + exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339) + b, err := json.Marshal(exp) + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + tmpFile := strings.Join( + []string{"testdata", "tmp_force_expire.json"}, + string(os.PathSeparator)) + if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + defer func() { + if err = os.Remove(tmpFile); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + }() + + // get credentials from file + creds := processcreds.NewCredentials( + fmt.Sprintf("%s %s", getOSCat(), tmpFile)) + if _, err = creds.Get(); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + if creds.IsExpired() { + t.Errorf("expected %v, got %v", "not expired", "expired") + } + + // force expire creds + creds.Expire() + if !creds.IsExpired() { + t.Errorf("expected %v, got %v", "expired", "not expired") + } + + // renew creds + if _, err = creds.Get(); err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + if creds.IsExpired() { + t.Errorf("expected %v, got %v", "not expired", "expired") + } + +} + +func TestProcessProviderAltConstruct(t *testing.T) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + // constructing with exec.Cmd instead of string + myCommand := exec.Command( + fmt.Sprintf( + "%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "static.json"}, + string(os.PathSeparator)))) + creds := processcreds.NewCredentialsCommand(myCommand, func(opt *processcreds.ProcessProvider) { + opt.Timeout = time.Duration(1) * time.Second + }) + _, err := creds.Get() + if err != nil { + t.Errorf("expected %v, got %v", "no error", err) + } + if creds.IsExpired() { + t.Errorf("expected %v, got %v", "static credentials/not expired", "expired") + } +} + +func BenchmarkProcessProvider(b *testing.B) { + oldEnv := preserveImportantStashEnv() + defer awstesting.PopEnv(oldEnv) + + creds := processcreds.NewCredentials( + fmt.Sprintf( + "%s %s", + getOSCat(), + strings.Join( + []string{"testdata", "static.json"}, + string(os.PathSeparator)))) + _, err := creds.Get() + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := creds.Get() + if err != nil { + b.Fatal(err) + } + } +} + +func preserveImportantStashEnv() []string { + envsToKeep := []string{"PATH"} + + if runtime.GOOS == "windows" { + envsToKeep = append(envsToKeep, "ComSpec") + envsToKeep = append(envsToKeep, "SYSTEM32") + } + + extraEnv := getEnvs(envsToKeep) + + oldEnv := awstesting.StashEnv() //clear env + + for key, val := range extraEnv { + os.Setenv(key, val) + } + + return oldEnv +} + +func getEnvs(envs []string) map[string]string { + extraEnvs := make(map[string]string) + for _, env := range envs { + if val, ok := os.LookupEnv(env); ok && len(val) > 0 { + extraEnvs[env] = val + } + } + return extraEnvs +} + +func getOSCat() string { + if runtime.GOOS == "windows" { + return "type" + } + return "cat" +} diff --git a/aws/credentials/processcreds/testdata/expired.json b/aws/credentials/processcreds/testdata/expired.json new file mode 100644 index 00000000000..00753a8d12b --- /dev/null +++ b/aws/credentials/processcreds/testdata/expired.json @@ -0,0 +1,7 @@ +{ + "Version": 1, + "AccessKeyId": "accessKey", + "SecretAccessKey": "secret", + "SessionToken": "tokenDefault", + "Expiration": "2000-01-01T00:00:00-00:00" +} diff --git a/aws/credentials/processcreds/testdata/malformed.json b/aws/credentials/processcreds/testdata/malformed.json new file mode 100644 index 00000000000..1e9652b423d --- /dev/null +++ b/aws/credentials/processcreds/testdata/malformed.json @@ -0,0 +1,2 @@ +{ + "Version": 1 diff --git a/aws/credentials/processcreds/testdata/missingkey.json b/aws/credentials/processcreds/testdata/missingkey.json new file mode 100644 index 00000000000..ea54b015536 --- /dev/null +++ b/aws/credentials/processcreds/testdata/missingkey.json @@ -0,0 +1,4 @@ +{ + "Version": 1, + "AccessKeyId": "accesskey" +} diff --git a/aws/credentials/processcreds/testdata/missingsecret.json b/aws/credentials/processcreds/testdata/missingsecret.json new file mode 100644 index 00000000000..c8740b13f4d --- /dev/null +++ b/aws/credentials/processcreds/testdata/missingsecret.json @@ -0,0 +1,4 @@ +{ + "Version": 1, + "SecretAccessKey": "secretkey" +} diff --git a/aws/credentials/processcreds/testdata/nonexpire.json b/aws/credentials/processcreds/testdata/nonexpire.json new file mode 100644 index 00000000000..5e567131a47 --- /dev/null +++ b/aws/credentials/processcreds/testdata/nonexpire.json @@ -0,0 +1,6 @@ +{ + "Version": 1, + "AccessKeyId": "accessKey", + "SecretAccessKey": "secret", + "SessionToken": "nonDefaultToken" +} diff --git a/aws/credentials/processcreds/testdata/shconfig.ini b/aws/credentials/processcreds/testdata/shconfig.ini new file mode 100644 index 00000000000..9c236946c2b --- /dev/null +++ b/aws/credentials/processcreds/testdata/shconfig.ini @@ -0,0 +1,10 @@ +[default] +credential_process = cat ./testdata/expired.json + +[profile non_expire] +credential_process = cat ./testdata/nonexpire.json + +[profile not_alone] +aws_access_key_id = notFromCredProcAccess +aws_secret_access_key = notFromCredProcSecret +credential_process = cat ./testdata/verybad.json diff --git a/aws/credentials/processcreds/testdata/shconfig_win.ini b/aws/credentials/processcreds/testdata/shconfig_win.ini new file mode 100644 index 00000000000..59318d88e54 --- /dev/null +++ b/aws/credentials/processcreds/testdata/shconfig_win.ini @@ -0,0 +1,10 @@ +[default] +credential_process = type .\testdata\expired.json + +[profile non_expire] +credential_process = type .\testdata\nonexpire.json + +[profile not_alone] +aws_access_key_id = notFromCredProcAccess +aws_secret_access_key = notFromCredProcSecret +credential_process = type .\testdata\verybad.json diff --git a/aws/credentials/processcreds/testdata/shcred.ini b/aws/credentials/processcreds/testdata/shcred.ini new file mode 100644 index 00000000000..81ca26ba960 --- /dev/null +++ b/aws/credentials/processcreds/testdata/shcred.ini @@ -0,0 +1,10 @@ +[default] +credential_process = cat ./testdata/expired.json + +[non_expire] +credential_process = cat ./testdata/nonexpire.json + +[not_alone] +aws_access_key_id = notFromCredProcAccess +aws_secret_access_key = notFromCredProcSecret +credential_process = cat ./testdata/verybad.json diff --git a/aws/credentials/processcreds/testdata/shcred_win.ini b/aws/credentials/processcreds/testdata/shcred_win.ini new file mode 100644 index 00000000000..ad4559c258d --- /dev/null +++ b/aws/credentials/processcreds/testdata/shcred_win.ini @@ -0,0 +1,10 @@ +[default] +credential_process = type .\testdata\expired.json + +[non_expire] +credential_process = type .\testdata\nonexpire.json + +[not_alone] +aws_access_key_id = notFromCredProcAccess +aws_secret_access_key = notFromCredProcSecret +credential_process = type .\testdata\verybad.json diff --git a/aws/credentials/processcreds/testdata/static.json b/aws/credentials/processcreds/testdata/static.json new file mode 100644 index 00000000000..9fddfa123fc --- /dev/null +++ b/aws/credentials/processcreds/testdata/static.json @@ -0,0 +1,5 @@ +{ + "Version":1, + "AccessKeyId":"accesskey", + "SecretAccessKey":"secretkey" +} diff --git a/aws/credentials/processcreds/testdata/verybad.json b/aws/credentials/processcreds/testdata/verybad.json new file mode 100644 index 00000000000..968883b8b80 --- /dev/null +++ b/aws/credentials/processcreds/testdata/verybad.json @@ -0,0 +1,5 @@ +{ + "Version":1, + "AccessKeyId":"veryBadAccessKeyID", + "SecretAccessKey":"veryBadSecretAccessKey" +} diff --git a/aws/credentials/processcreds/testdata/wrongversion.json b/aws/credentials/processcreds/testdata/wrongversion.json new file mode 100644 index 00000000000..a58ea78dcef --- /dev/null +++ b/aws/credentials/processcreds/testdata/wrongversion.json @@ -0,0 +1,3 @@ +{ + "Version": 2 +} diff --git a/aws/session/session.go b/aws/session/session.go index e7c156e8b12..9bdbafd65cc 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/processcreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/csm" "github.com/aws/aws-sdk-go/aws/defaults" @@ -534,6 +535,10 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share cfg.Credentials = credentials.NewStaticCredentialsFromCreds( sharedCfg.Creds, ) + } else if len(sharedCfg.CredentialProcess) > 0 { + cfg.Credentials = processcreds.NewCredentials( + sharedCfg.CredentialProcess, + ) } else { // Fallback to default credentials provider, include mock errors // for the credential chain so user can identify why credentials diff --git a/aws/session/shared_config.go b/aws/session/shared_config.go index 427b8a4e997..7cb44021b3f 100644 --- a/aws/session/shared_config.go +++ b/aws/session/shared_config.go @@ -28,6 +28,8 @@ const ( // endpoint discovery group enableEndpointDiscoveryKey = `endpoint_discovery_enabled` // optional + // External Credential Process + credentialProcessKey = `credential_process` // DefaultSharedConfigProfile is the default profile to be used when // loading configuration from the config files if another profile name @@ -60,6 +62,9 @@ type sharedConfig struct { AssumeRole assumeRoleConfig AssumeRoleSource *sharedConfig + // An external process to request credentials + CredentialProcess string + // Region is the region the SDK should use for looking up AWS service endpoints // and signing requests. // @@ -223,6 +228,11 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile) e } } + // `credential_process` + if credProc := section.String(credentialProcessKey); len(credProc) > 0 { + cfg.CredentialProcess = credProc + } + // Region if v := section.String(regionKey); len(v) > 0 { cfg.Region = v