Skip to content
This repository was archived by the owner on Jul 31, 2025. It is now read-only.

Commit e3bbdfb

Browse files
committed
Implement Credential Chain Support for SSO Provider
1 parent 308d285 commit e3bbdfb

File tree

7 files changed

+357
-39
lines changed

7 files changed

+357
-39
lines changed

aws/session/credentials.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/aws/aws-sdk-go/aws/awserr"
1010
"github.com/aws/aws-sdk-go/aws/credentials"
1111
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
12+
"github.com/aws/aws-sdk-go/aws/credentials/ssocreds"
1213
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
1314
"github.com/aws/aws-sdk-go/aws/defaults"
1415
"github.com/aws/aws-sdk-go/aws/request"
@@ -100,6 +101,9 @@ func resolveCredsFromProfile(cfg *aws.Config,
100101
sharedCfg.Creds,
101102
)
102103

104+
case sharedCfg.hasSSOConfiguration():
105+
creds = resolveSSOCredentials(cfg, sharedCfg, handlers)
106+
103107
case len(sharedCfg.CredentialProcess) != 0:
104108
// Get credentials from CredentialProcess
105109
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
@@ -151,6 +155,21 @@ func resolveCredsFromProfile(cfg *aws.Config,
151155
return creds, nil
152156
}
153157

158+
func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers request.Handlers) *credentials.Credentials {
159+
cfgCopy := cfg.Copy()
160+
cfgCopy.Region = &sharedCfg.SSORegion
161+
162+
return ssocreds.NewCredentials(
163+
&Session{
164+
Config: cfgCopy,
165+
Handlers: handlers.Copy(),
166+
},
167+
sharedCfg.SSOAccountID,
168+
sharedCfg.SSORoleName,
169+
sharedCfg.SSOStartURL,
170+
)
171+
}
172+
154173
// valid credential source values
155174
const (
156175
credSourceEc2Metadata = "Ec2InstanceMetadata"

aws/session/credentials_test.go

Lines changed: 154 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ package session
44

55
import (
66
"fmt"
7+
"io/ioutil"
78
"net/http"
89
"net/http/httptest"
910
"os"
11+
"path/filepath"
1012
"reflect"
1113
"runtime"
1214
"strconv"
@@ -68,6 +70,14 @@ func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) {
6870
Format("2006-01-02T15:04:05Z"))))
6971
}))
7072

73+
ssoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
74+
w.Write([]byte(fmt.Sprintf(
75+
getRoleCredentialsResponse,
76+
time.Now().
77+
Add(15*time.Minute).
78+
UnixNano()/int64(time.Millisecond))))
79+
}))
80+
7181
resolver := endpoints.ResolverFunc(
7282
func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
7383
switch service {
@@ -79,6 +89,10 @@ func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) {
7989
return endpoints.ResolvedEndpoint{
8090
URL: stsServer.URL,
8191
}, nil
92+
case "portal.sso":
93+
return endpoints.ResolvedEndpoint{
94+
URL: ssoServer.URL,
95+
}, nil
8296
default:
8397
return endpoints.ResolvedEndpoint{},
8498
fmt.Errorf("unknown service endpoint, %s", service)
@@ -89,6 +103,7 @@ func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) {
89103
shareddefaults.ECSContainerCredentialsURI = origECSEndpoint
90104
ecsMetadataServer.Close()
91105
ec2MetadataServer.Close()
106+
ssoServer.Close()
92107
stsServer.Close()
93108
}
94109
}
@@ -105,30 +120,34 @@ func TestSharedConfigCredentialSource(t *testing.T) {
105120
expectedError error
106121
expectedAccessKey string
107122
expectedSecretKey string
123+
expectedSessionToken string
108124
expectedChain []string
109-
init func()
125+
init func() (func(), error)
110126
dependentOnOS bool
111127
}{
112128
{
113129
name: "credential source and source profile",
114130
profile: "invalid_source_and_credential_source",
115131
expectedError: ErrSharedConfigSourceCollision,
116-
init: func() {
132+
init: func() (func(), error) {
117133
os.Setenv("AWS_ACCESS_KEY", "access_key")
118134
os.Setenv("AWS_SECRET_KEY", "secret_key")
135+
return func() {}, nil
119136
},
120137
},
121138
{
122-
name: "env var credential source",
123-
sessOptProfile: "env_var_credential_source",
124-
expectedAccessKey: "AKID",
125-
expectedSecretKey: "SECRET",
139+
name: "env var credential source",
140+
sessOptProfile: "env_var_credential_source",
141+
expectedAccessKey: "AKID",
142+
expectedSecretKey: "SECRET",
143+
expectedSessionToken: "SESSION_TOKEN",
126144
expectedChain: []string{
127145
"assume_role_w_creds_role_arn_env",
128146
},
129-
init: func() {
147+
init: func() (func(), error) {
130148
os.Setenv("AWS_ACCESS_KEY", "access_key")
131149
os.Setenv("AWS_SECRET_KEY", "secret_key")
150+
return func() {}, nil
132151
},
133152
},
134153
{
@@ -137,36 +156,42 @@ func TestSharedConfigCredentialSource(t *testing.T) {
137156
expectedChain: []string{
138157
"assume_role_w_creds_role_arn_ec2",
139158
},
140-
expectedAccessKey: "AKID",
141-
expectedSecretKey: "SECRET",
159+
expectedAccessKey: "AKID",
160+
expectedSecretKey: "SECRET",
161+
expectedSessionToken: "SESSION_TOKEN",
142162
},
143163
{
144-
name: "ec2metadata custom EC2 IMDS endpoint, env var",
145-
profile: "not-exists-profile",
146-
expectedAccessKey: "ec2_custom_key",
147-
expectedSecretKey: "ec2_custom_secret",
148-
init: func() {
164+
name: "ec2metadata custom EC2 IMDS endpoint, env var",
165+
profile: "not-exists-profile",
166+
expectedAccessKey: "ec2_custom_key",
167+
expectedSecretKey: "ec2_custom_secret",
168+
expectedSessionToken: "token",
169+
init: func() (func(), error) {
149170
altServer := newEc2MetadataServer("ec2_custom_key", "ec2_custom_secret", true)
150171
os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", altServer.URL)
172+
return func() {}, nil
151173
},
152174
},
153175
{
154-
name: "ecs container credential source",
155-
profile: "ecscontainer",
156-
expectedAccessKey: "AKID",
157-
expectedSecretKey: "SECRET",
176+
name: "ecs container credential source",
177+
profile: "ecscontainer",
178+
expectedAccessKey: "AKID",
179+
expectedSecretKey: "SECRET",
180+
expectedSessionToken: "SESSION_TOKEN",
158181
expectedChain: []string{
159182
"assume_role_w_creds_role_arn_ecs",
160183
},
161-
init: func() {
184+
init: func() (func(), error) {
162185
os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
186+
return func() {}, nil
163187
},
164188
},
165189
{
166-
name: "chained assume role with env creds",
167-
profile: "chained_assume_role",
168-
expectedAccessKey: "AKID",
169-
expectedSecretKey: "SECRET",
190+
name: "chained assume role with env creds",
191+
profile: "chained_assume_role",
192+
expectedAccessKey: "AKID",
193+
expectedSecretKey: "SECRET",
194+
expectedSessionToken: "SESSION_TOKEN",
170195
expectedChain: []string{
171196
"assume_role_w_creds_role_arn_chain",
172197
"assume_role_w_creds_role_arn_ec2",
@@ -180,25 +205,60 @@ func TestSharedConfigCredentialSource(t *testing.T) {
180205
expectedSecretKey: "cred_proc_secret",
181206
},
182207
{
183-
name: "credential process with ARN set",
184-
profile: "cred_proc_arn_set",
185-
dependentOnOS: true,
186-
expectedAccessKey: "AKID",
187-
expectedSecretKey: "SECRET",
208+
name: "credential process with ARN set",
209+
profile: "cred_proc_arn_set",
210+
dependentOnOS: true,
211+
expectedAccessKey: "AKID",
212+
expectedSecretKey: "SECRET",
213+
expectedSessionToken: "SESSION_TOKEN",
188214
expectedChain: []string{
189215
"assume_role_w_creds_proc_role_arn",
190216
},
191217
},
192218
{
193-
name: "chained assume role with credential process",
194-
profile: "chained_cred_proc",
195-
dependentOnOS: true,
196-
expectedAccessKey: "AKID",
197-
expectedSecretKey: "SECRET",
219+
name: "chained assume role with credential process",
220+
profile: "chained_cred_proc",
221+
dependentOnOS: true,
222+
expectedAccessKey: "AKID",
223+
expectedSecretKey: "SECRET",
224+
expectedSessionToken: "SESSION_TOKEN",
198225
expectedChain: []string{
199226
"assume_role_w_creds_proc_source_prof",
200227
},
201228
},
229+
{
230+
name: "sso credentials",
231+
profile: "sso_creds",
232+
expectedAccessKey: "SSO_AKID",
233+
expectedSecretKey: "SSO_SECRET_KEY",
234+
expectedSessionToken: "SSO_SESSION_TOKEN",
235+
init: func() (func(), error) {
236+
return ssoTestSetup()
237+
},
238+
},
239+
{
240+
name: "chained assume role with sso credentials",
241+
profile: "source_sso_creds",
242+
expectedAccessKey: "AKID",
243+
expectedSecretKey: "SECRET",
244+
expectedSessionToken: "SESSION_TOKEN",
245+
expectedChain: []string{
246+
"source_sso_creds_arn",
247+
},
248+
init: func() (func(), error) {
249+
return ssoTestSetup()
250+
},
251+
},
252+
{
253+
name: "chained assume role with sso and static credentials",
254+
profile: "assume_sso_and_static",
255+
expectedAccessKey: "AKID",
256+
expectedSecretKey: "SECRET",
257+
expectedSessionToken: "SESSION_TOKEN",
258+
expectedChain: []string{
259+
"assume_sso_and_static_arn",
260+
},
261+
},
202262
}
203263

204264
for i, c := range cases {
@@ -222,7 +282,11 @@ func TestSharedConfigCredentialSource(t *testing.T) {
222282
defer cleanupFn()
223283

224284
if c.init != nil {
225-
c.init()
285+
cleanup, err := c.init()
286+
if err != nil {
287+
t.Fatalf("expect no error, got %v", err)
288+
}
289+
defer cleanup()
226290
}
227291

228292
var credChain []string
@@ -268,6 +332,10 @@ func TestSharedConfigCredentialSource(t *testing.T) {
268332
if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a {
269333
t.Errorf("expected %v, but received %v", e, a)
270334
}
335+
336+
if e, a := c.expectedSessionToken, creds.SessionToken; e != a {
337+
t.Errorf("expected %v, but received %v", e, a)
338+
}
271339
})
272340
}
273341
}
@@ -312,6 +380,20 @@ const assumeRoleRespMsg = `
312380
</AssumeRoleResponse>
313381
`
314382

383+
const getRoleCredentialsResponse = `{
384+
"roleCredentials": {
385+
"accessKeyId": "SSO_AKID",
386+
"secretAccessKey": "SSO_SECRET_KEY",
387+
"sessionToken": "SSO_SESSION_TOKEN",
388+
"expiration": %d
389+
}
390+
}`
391+
392+
const ssoTokenCacheFile = `{
393+
"accessToken": "ssoAccessToken",
394+
"expiresAt": "%s"
395+
}`
396+
315397
func TestSessionAssumeRole(t *testing.T) {
316398
restoreEnvFn := initSessionTestEnv()
317399
defer restoreEnvFn()
@@ -647,3 +729,41 @@ func TestSessionAssumeRole_WithMFA_ExtendedDuration(t *testing.T) {
647729
t.Errorf("expect %v, to be in %v", e, a)
648730
}
649731
}
732+
733+
func ssoTestSetup() (func(), error) {
734+
dir, err := ioutil.TempDir("", "sso-test")
735+
if err != nil {
736+
return nil, err
737+
}
738+
739+
cacheDir := filepath.Join(dir, ".aws", "sso", "cache")
740+
err = os.MkdirAll(cacheDir, 0750)
741+
if err != nil {
742+
os.RemoveAll(dir)
743+
return nil, err
744+
}
745+
746+
tokenFile, err := os.Create(filepath.Join(cacheDir, "eb5e43e71ce87dd92ec58903d76debd8ee42aefd.json"))
747+
if err != nil {
748+
os.RemoveAll(dir)
749+
return nil, err
750+
}
751+
defer tokenFile.Close()
752+
753+
_, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now().
754+
Add(15*time.Minute).
755+
Format(time.RFC3339)))
756+
if err != nil {
757+
os.RemoveAll(dir)
758+
return nil, err
759+
}
760+
761+
if runtime.GOOS == "windows" {
762+
os.Setenv("USERPROFILE", dir)
763+
} else {
764+
os.Setenv("HOME", dir)
765+
}
766+
767+
return func() {
768+
}, nil
769+
}

aws/session/session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ const (
3636

3737
// ErrSharedConfigSourceCollision will be returned if a section contains both
3838
// source_profile and credential_source
39-
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only source profile or credential source can be specified, not both", nil)
39+
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only one credential type may be specified per profile: source profile, credential source, credential process, web identity token, or sso", nil)
4040

4141
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
4242
// variables are empty and Environment was set as the credential source

0 commit comments

Comments
 (0)