Skip to content

Commit 15c1f15

Browse files
chore: task role assume role (#5476)
Implements the first of two planned methods to retrieve TaskRole credentials for `run local` containers By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the Apache 2.0 License.
1 parent dae144e commit 15c1f15

4 files changed

Lines changed: 182 additions & 0 deletions

File tree

internal/pkg/cli/errors.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,25 @@ func (e *errPipelineDependsOnEnv) RecommendActions() string {
163163
or run %s to delete the pipeline before running %s to delete the environment`,
164164
e.pipeline, e.env, color.HighlightCode(fmt.Sprintf("copilot pipeline delete -n %s", e.pipeline)), color.HighlightCode(fmt.Sprintf("copilot env delete -n %s", e.env)))
165165
}
166+
167+
type errTaskRoleRetrievalFailed struct {
168+
chainErrs []error
169+
}
170+
171+
func (e *errTaskRoleRetrievalFailed) Error() string {
172+
return errors.Join(e.chainErrs...).Error()
173+
}
174+
175+
func (e *errTaskRoleRetrievalFailed) RecommendActions() string {
176+
return fmt.Sprintf(`TaskRole retrieval failed. You can manually add permissions for your account to assume TaskRole by adding the following YAML override to your service:
177+
%s
178+
For more information on YAML overrides see %s`,
179+
color.HighlightCodeBlock(`- op: add
180+
path: /Resources/TaskRole/Properties/AssumeRolePolicyDocument/Statement/-
181+
value:
182+
Effect: Allow
183+
Principal:
184+
AWS: "arn:aws:iam::[app-account-ID]:root"
185+
Action: 'sts:AssumeRole'`),
186+
color.Emphasize("https://aws.github.io/copilot-cli/docs/developing/overrides/yamlpatch/"))
187+
}

internal/pkg/cli/flag.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ const (
7272
proxyFlag = "proxy"
7373
proxyNetworkFlag = "proxy-network"
7474
watchFlag = "watch"
75+
useTaskRoleFlag = "use-task-role"
7576

7677
// Flags for CI/CD.
7778
githubURLFlag = "github-url"
@@ -326,6 +327,7 @@ Example: --port-override 5000:80 binds localhost:5000 to the service's port 80.`
326327
proxyFlagDescription = `Optional. Proxy outbound requests to your environment's VPC.`
327328
proxyNetworkFlagDescription = `Optional. Set the IP Network used by --proxy.`
328329
watchFlagDescription = `Optional. Watch changes to local files and restart containers when updated.`
330+
useTaskRoleFlagDescription = "Optional. Run containers with TaskRole credentials instead of session credentials."
329331

330332
svcManifestFlagDescription = `Optional. Name of the environment in which the service was deployed;
331333
output the manifest file used for that deployment.`

internal/pkg/cli/run_local.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ type runLocalVars struct {
9898
envName string
9999
envOverrides map[string]string
100100
watch bool
101+
useTaskRole bool
101102
portOverrides portOverrides
102103
proxy bool
103104
proxyNetwork net.IPNet
@@ -436,6 +437,23 @@ func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) {
436437
return orchestrator.Task{}, fmt.Errorf("get env vars: %w", err)
437438
}
438439

440+
if o.useTaskRole {
441+
taskRoleCredsVars, err := o.taskRoleCredentials(ctx)
442+
if err != nil {
443+
return orchestrator.Task{}, fmt.Errorf("retrieve task role credentials: %w", err)
444+
}
445+
446+
// overwrite environment variables
447+
for ctr := range envVars {
448+
for k, v := range taskRoleCredsVars {
449+
envVars[ctr][k] = envVarValue{
450+
Value: v,
451+
Secret: true,
452+
}
453+
}
454+
}
455+
}
456+
439457
containerDeps := o.getContainerDependencies(td)
440458

441459
task := orchestrator.Task{
@@ -629,6 +647,46 @@ func sessionEnvVars(ctx context.Context, sess *session.Session) (map[string]stri
629647
return env, nil
630648
}
631649

650+
func (o *runLocalOpts) taskRoleCredentials(ctx context.Context) (map[string]string, error) {
651+
// assumeRoleMethod tries to directly call sts:AssumeRole for TaskRole using default session
652+
// calls sts:AssumeRole through aws-sdk-go here https://github.com/aws/aws-sdk-go/blob/ac58203a9054cc9d901429bdd94edfc0a7a1de46/aws/credentials/stscreds/assume_role_provider.go#L352
653+
assumeRoleMethod := func() (map[string]string, error) {
654+
taskDef, err := o.ecsClient.TaskDefinition(o.appName, o.envName, o.wkldName)
655+
if err != nil {
656+
return nil, err
657+
}
658+
659+
taskRoleSess, err := o.sessProvider.FromRole(aws.StringValue(taskDef.TaskRoleArn), o.targetEnv.Region)
660+
if err != nil {
661+
return nil, err
662+
}
663+
664+
return sessionEnvVars(ctx, taskRoleSess)
665+
}
666+
667+
// ecsExecMethod tries to use ECS Exec to retrive credentials from running container
668+
ecsExecMethod := func() (map[string]string, error) {
669+
return nil, errors.New("ecs exec method not implemented")
670+
}
671+
672+
credentialsChain := []func() (map[string]string, error){
673+
assumeRoleMethod,
674+
ecsExecMethod,
675+
}
676+
677+
// return TaskRole credentials from first successful method
678+
var errs []error
679+
for _, method := range credentialsChain {
680+
vars, err := method()
681+
if err == nil {
682+
return vars, nil
683+
}
684+
errs = append(errs, err)
685+
}
686+
687+
return nil, &errTaskRoleRetrievalFailed{errs}
688+
}
689+
632690
type containerEnv map[string]envVarValue
633691

634692
type envVarValue struct {

internal/pkg/cli/run_local_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ type runLocalExecuteMocks struct {
203203
ecsClient *mocks.MockecsClient
204204
store *mocks.Mockstore
205205
sessCreds credentials.Provider
206+
sessProvider *mocks.MocksessionProvider
206207
interpolator *mocks.Mockinterpolator
207208
ws *mocks.MockwsWlDirReader
208209
mockMft *mockWorkloadMft
@@ -268,6 +269,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
268269
}
269270

270271
taskDef := &awsecs.TaskDefinition{
272+
TaskRoleArn: aws.String("mock-arn"),
271273
ContainerDefinitions: []*sdkecs.ContainerDefinition{
272274
{
273275
Name: aws.String("foo"),
@@ -328,6 +330,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
328330
},
329331
}
330332
alteredTaskDef := &awsecs.TaskDefinition{
333+
TaskRoleArn: aws.String("mock-arn"),
331334
ContainerDefinitions: []*sdkecs.ContainerDefinition{
332335
{
333336
Name: aws.String("foo"),
@@ -429,6 +432,52 @@ func TestRunLocalOpts_Execute(t *testing.T) {
429432
"AWS_SESSION_TOKEN": "myEnvToken",
430433
},
431434
}
435+
expectedTaskRoleTask := orchestrator.Task{
436+
Containers: map[string]orchestrator.ContainerDefinition{
437+
"foo": {
438+
ImageURI: "image1",
439+
EnvVars: map[string]string{
440+
"FOO_VAR": "foo-value",
441+
},
442+
Secrets: map[string]string{
443+
"SHARED_SECRET": "secretvalue",
444+
"AWS_ACCESS_KEY_ID": "taskRoleID",
445+
"AWS_SECRET_ACCESS_KEY": "taskRoleSecret",
446+
"AWS_SESSION_TOKEN": "taskRoleToken",
447+
"AWS_DEFAULT_REGION": testRegion,
448+
"AWS_REGION": testRegion,
449+
},
450+
Ports: map[string]string{
451+
"80": "8080",
452+
"999": "9999",
453+
},
454+
IsEssential: true,
455+
DependsOn: map[string]string{
456+
"bar": "start",
457+
},
458+
},
459+
"bar": {
460+
ImageURI: "image2",
461+
EnvVars: map[string]string{
462+
"BAR_VAR": "bar-value",
463+
},
464+
Secrets: map[string]string{
465+
"SHARED_SECRET": "secretvalue",
466+
"AWS_ACCESS_KEY_ID": "taskRoleID",
467+
"AWS_SECRET_ACCESS_KEY": "taskRoleSecret",
468+
"AWS_SESSION_TOKEN": "taskRoleToken",
469+
"AWS_DEFAULT_REGION": testRegion,
470+
"AWS_REGION": testRegion,
471+
},
472+
Ports: map[string]string{
473+
"777": "7777",
474+
"10000": "10000",
475+
},
476+
IsEssential: true,
477+
DependsOn: map[string]string{},
478+
},
479+
},
480+
}
432481

433482
testCases := map[string]struct {
434483
inputAppName string
@@ -437,6 +486,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
437486
inputEnvOverrides map[string]string
438487
inputPortOverrides []string
439488
inputWatch bool
489+
inputTaskRole bool
440490
inputProxy bool
441491
buildImagesError error
442492

@@ -467,6 +517,20 @@ func TestRunLocalOpts_Execute(t *testing.T) {
467517
},
468518
wantedError: errors.New(`get task: get env vars: parse env overrides: "bad:OVERRIDE" targets invalid container`),
469519
},
520+
"error retrieving task role credentials": {
521+
inputAppName: testAppName,
522+
inputWkldName: testWkldName,
523+
inputEnvName: testEnvName,
524+
inputTaskRole: true,
525+
setupMocks: func(t *testing.T, m *runLocalExecuteMocks) {
526+
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
527+
m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil)
528+
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
529+
m.sessProvider.EXPECT().FromRole("mock-arn", testRegion).Return(nil, errors.New("some error"))
530+
},
531+
wantedError: errors.New(`get task: retrieve task role credentials: some error
532+
ecs exec method not implemented`),
533+
},
470534
"error reading workload manifest": {
471535
inputAppName: testAppName,
472536
inputWkldName: testWkldName,
@@ -743,6 +807,39 @@ func TestRunLocalOpts_Execute(t *testing.T) {
743807
}
744808
},
745809
},
810+
"success, one run task call, taskrole assumerole method": {
811+
inputAppName: testAppName,
812+
inputWkldName: testWkldName,
813+
inputEnvName: testEnvName,
814+
inputTaskRole: true,
815+
setupMocks: func(t *testing.T, m *runLocalExecuteMocks) {
816+
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
817+
m.ssm.EXPECT().GetSecretValue(gomock.Any(), "mysecret").Return("secretvalue", nil)
818+
m.ecsClient.EXPECT().TaskDefinition(testAppName, testEnvName, testWkldName).Return(taskDef, nil)
819+
taskRoleSess := &session.Session{
820+
Config: &aws.Config{
821+
Credentials: credentials.NewStaticCredentials("taskRoleID", "taskRoleSecret", "taskRoleToken"),
822+
Region: aws.String(testRegion),
823+
},
824+
}
825+
m.sessProvider.EXPECT().FromRole("mock-arn", testRegion).Return(taskRoleSess, nil)
826+
m.ws.EXPECT().ReadWorkloadManifest(testWkldName).Return([]byte(""), nil)
827+
m.interpolator.EXPECT().Interpolate("").Return("", nil)
828+
829+
errCh := make(chan error, 1)
830+
m.orchestrator.StartFn = func() <-chan error {
831+
errCh <- errors.New("some error")
832+
return errCh
833+
}
834+
m.orchestrator.RunTaskFn = func(task orchestrator.Task, opts ...orchestrator.RunTaskOption) {
835+
require.Equal(t, expectedTaskRoleTask, task)
836+
}
837+
m.orchestrator.StopFn = func() {
838+
require.Len(t, errCh, 0)
839+
close(errCh)
840+
}
841+
},
842+
},
746843
"handles ctrl-c, waits to get all errors": {
747844
inputAppName: testAppName,
748845
inputWkldName: testWkldName,
@@ -959,6 +1056,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
9591056
ssm: mocks.NewMocksecretGetter(ctrl),
9601057
secretsManager: mocks.NewMocksecretGetter(ctrl),
9611058
store: mocks.NewMockstore(ctrl),
1059+
sessProvider: mocks.NewMocksessionProvider(ctrl),
9621060
interpolator: mocks.NewMockinterpolator(ctrl),
9631061
ws: mocks.NewMockwsWlDirReader(ctrl),
9641062
mockRunner: mocks.NewMockexecRunner(ctrl),
@@ -978,6 +1076,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
9781076
envName: tc.inputEnvName,
9791077
envOverrides: tc.inputEnvOverrides,
9801078
watch: tc.inputWatch,
1079+
useTaskRole: tc.inputTaskRole,
9811080
portOverrides: portOverrides{
9821081
{
9831082
host: "777",
@@ -1007,6 +1106,7 @@ func TestRunLocalOpts_Execute(t *testing.T) {
10071106
ssm: m.ssm,
10081107
secretsManager: m.secretsManager,
10091108
store: m.store,
1109+
sessProvider: m.sessProvider,
10101110
sess: &session.Session{
10111111
Config: &aws.Config{
10121112
Credentials: credentials.NewStaticCredentials("myID", "mySecret", "myToken"),

0 commit comments

Comments
 (0)