Skip to content

Commit 2a77af3

Browse files
authored
Merge commit from fork
Signed-off-by: Jan Martens <[email protected]>
1 parent a5290d7 commit 2a77af3

File tree

4 files changed

+121
-26
lines changed

4 files changed

+121
-26
lines changed

auth/aws/backend.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ type backend struct {
7373
// of tidyCooldownPeriod.
7474
nextTidyTime time.Time
7575

76-
// Map to hold the EC2 client objects indexed by region and STS role.
76+
// Map to hold the EC2 client objects indexed by region, account ID, and STS role.
7777
// This avoids the overhead of creating a client object for every login request.
7878
// When the credentials are modified or deleted, all the cached client objects
7979
// will be flushed. The empty STS role signifies the master account
80-
EC2ClientsMap map[string]map[string]*ec2.EC2
80+
EC2ClientsMap map[string]map[string]map[string]*ec2.EC2
8181

82-
// Map to hold the IAM client objects indexed by region and STS role.
82+
// Map to hold the IAM client objects indexed by region, account ID, and STS role.
8383
// This avoids the overhead of creating a client object for every login request.
8484
// When the credentials are modified or deleted, all the cached client objects
8585
// will be flushed. The empty STS role signifies the master account
86-
IAMClientsMap map[string]map[string]*iam.IAM
86+
IAMClientsMap map[string]map[string]map[string]*iam.IAM
8787

8888
// Map to associate a partition to a random region in that partition. Users of
8989
// this don't care what region in the partition they use, but there is some client
@@ -122,8 +122,8 @@ func Backend(_ *logical.BackendConfig) (*backend, error) {
122122
// Setting the periodic func to be run once in an hour.
123123
// If there is a real need, this can be made configurable.
124124
tidyCooldownPeriod: time.Hour,
125-
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
126-
IAMClientsMap: make(map[string]map[string]*iam.IAM),
125+
EC2ClientsMap: make(map[string]map[string]map[string]*ec2.EC2),
126+
IAMClientsMap: make(map[string]map[string]map[string]*iam.IAM),
127127
iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour),
128128
tidyDenyListCASGuard: new(uint32),
129129
tidyAccessListCASGuard: new(uint32),

auth/aws/client.go

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, acco
190190
if sts != nil {
191191
return sts.StsRole, nil
192192
}
193+
194+
// Return an error if there's no STS config for an account which is not the default one
195+
if b.defaultAWSAccountID != "" && b.defaultAWSAccountID != accountID {
196+
return "", fmt.Errorf("no STS configuration found for account ID %q", accountID)
197+
}
198+
193199
return "", nil
194200
}
195201

@@ -200,20 +206,26 @@ func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, acco
200206
return nil, err
201207
}
202208
b.configMutex.RLock()
203-
if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
209+
if b.EC2ClientsMap[region] != nil &&
210+
b.EC2ClientsMap[region][accountID] != nil &&
211+
b.EC2ClientsMap[region][accountID][stsRole] != nil {
204212
defer b.configMutex.RUnlock()
205213
// If the client object was already created, return it
206-
return b.EC2ClientsMap[region][stsRole], nil
214+
b.Logger().Debug(fmt.Sprintf("returning cached client for region %s, account %s and stsRole %s", region, accountID, stsRole))
215+
return b.EC2ClientsMap[region][accountID][stsRole], nil
207216
}
217+
b.Logger().Debug(fmt.Sprintf("no cached client for region %s, account %s and stsRole %s", region, accountID, stsRole))
208218

209219
// Release the read lock and acquire the write lock
210220
b.configMutex.RUnlock()
211221
b.configMutex.Lock()
212222
defer b.configMutex.Unlock()
213223

214224
// If the client gets created while switching the locks, return it
215-
if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
216-
return b.EC2ClientsMap[region][stsRole], nil
225+
if b.EC2ClientsMap[region] != nil &&
226+
b.EC2ClientsMap[region][accountID] != nil &&
227+
b.EC2ClientsMap[region][accountID][stsRole] != nil {
228+
return b.EC2ClientsMap[region][accountID][stsRole], nil
217229
}
218230

219231
// Create an AWS config object using a chain of providers
@@ -237,13 +249,16 @@ func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, acco
237249
if client == nil {
238250
return nil, fmt.Errorf("could not obtain ec2 client")
239251
}
252+
240253
if _, ok := b.EC2ClientsMap[region]; !ok {
241-
b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
242-
} else {
243-
b.EC2ClientsMap[region][stsRole] = client
254+
b.EC2ClientsMap[region] = make(map[string]map[string]*ec2.EC2)
244255
}
256+
if _, ok := b.EC2ClientsMap[region][accountID]; !ok {
257+
b.EC2ClientsMap[region][accountID] = make(map[string]*ec2.EC2)
258+
}
259+
b.EC2ClientsMap[region][accountID][stsRole] = client
245260

246-
return b.EC2ClientsMap[region][stsRole], nil
261+
return b.EC2ClientsMap[region][accountID][stsRole], nil
247262
}
248263

249264
// clientIAM creates a client to interact with AWS IAM API
@@ -258,22 +273,26 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, acco
258273
b.Logger().Debug(fmt.Sprintf("found stsRole %s for account %s", stsRole, accountID))
259274
}
260275
b.configMutex.RLock()
261-
if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
276+
if b.IAMClientsMap[region] != nil &&
277+
b.IAMClientsMap[region][accountID] != nil &&
278+
b.IAMClientsMap[region][accountID][stsRole] != nil {
262279
defer b.configMutex.RUnlock()
263280
// If the client object was already created, return it
264-
b.Logger().Debug(fmt.Sprintf("returning cached client for region %s and stsRole %s", region, stsRole))
265-
return b.IAMClientsMap[region][stsRole], nil
281+
b.Logger().Debug(fmt.Sprintf("returning cached client for region %s, account %s and stsRole %s", region, accountID, stsRole))
282+
return b.IAMClientsMap[region][accountID][stsRole], nil
266283
}
267-
b.Logger().Debug(fmt.Sprintf("no cached client for region %s and stsRole %s", region, stsRole))
284+
b.Logger().Debug(fmt.Sprintf("no cached client for region %s, account %s and stsRole %s", region, accountID, stsRole))
268285

269286
// Release the read lock and acquire the write lock
270287
b.configMutex.RUnlock()
271288
b.configMutex.Lock()
272289
defer b.configMutex.Unlock()
273290

274291
// If the client gets created while switching the locks, return it
275-
if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
276-
return b.IAMClientsMap[region][stsRole], nil
292+
if b.IAMClientsMap[region] != nil &&
293+
b.IAMClientsMap[region][accountID] != nil &&
294+
b.IAMClientsMap[region][accountID][stsRole] != nil {
295+
return b.IAMClientsMap[region][accountID][stsRole], nil
277296
}
278297

279298
// Create an AWS config object using a chain of providers
@@ -297,10 +316,14 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, acco
297316
if client == nil {
298317
return nil, fmt.Errorf("could not obtain iam client")
299318
}
319+
300320
if _, ok := b.IAMClientsMap[region]; !ok {
301-
b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
302-
} else {
303-
b.IAMClientsMap[region][stsRole] = client
321+
b.IAMClientsMap[region] = make(map[string]map[string]*iam.IAM)
304322
}
305-
return b.IAMClientsMap[region][stsRole], nil
323+
if _, ok := b.IAMClientsMap[region][accountID]; !ok {
324+
b.IAMClientsMap[region][accountID] = make(map[string]*iam.IAM)
325+
}
326+
b.IAMClientsMap[region][accountID][stsRole] = client
327+
328+
return b.IAMClientsMap[region][accountID][stsRole], nil
306329
}

auth/aws/client_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (c) 2025 OpenBao a Series of LF Projects, LLC
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package aws
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"testing"
10+
11+
"github.com/openbao/openbao/sdk/v2/logical"
12+
)
13+
14+
// TestClientCache verifies that IAM clients for different
15+
// AWS accounts are properly isolated in the cache
16+
func TestClientCache(t *testing.T) {
17+
config := logical.TestBackendConfig()
18+
storage := &logical.InmemStorage{}
19+
config.StorageView = storage
20+
21+
b, err := Backend(config)
22+
if err != nil {
23+
t.Fatal(err)
24+
}
25+
26+
ctx := context.Background()
27+
if err := b.Setup(ctx, config); err != nil {
28+
t.Fatal(err)
29+
}
30+
31+
account1 := "111111111111"
32+
account2 := "222222222222"
33+
34+
b.defaultAWSAccountID = account1
35+
36+
// This should work - same account as default
37+
stsRole, err := b.stsRoleForAccount(ctx, storage, account1)
38+
if err != nil {
39+
t.Fatalf("Expected success for default account, got error: %v", err)
40+
}
41+
if stsRole != "" {
42+
t.Fatalf("Expected empty STS role for default account, got: %v", stsRole)
43+
}
44+
45+
// This should fail - different account without STS config
46+
_, err = b.stsRoleForAccount(ctx, storage, account2)
47+
if err == nil {
48+
t.Fatal("Expected error for cross-account access without STS config")
49+
}
50+
51+
// Verify the error message contains the expected error
52+
expectedError := fmt.Sprintf("no STS configuration found for account ID %q", account2)
53+
if err.Error() != expectedError {
54+
t.Fatalf("Expected specific error message, got: %v", err)
55+
}
56+
57+
stsEntry := &awsStsEntry{
58+
StsRole: "arn:aws:iam::222222222222:role/cross-account-role",
59+
}
60+
err = b.lockedSetAwsStsEntry(ctx, storage, account2, stsEntry)
61+
if err != nil {
62+
t.Fatalf("Failed to set STS entry: %v", err)
63+
}
64+
65+
stsRole, err = b.stsRoleForAccount(ctx, storage, account2)
66+
if err != nil {
67+
t.Fatalf("Expected success for account with STS config, got error: %v", err)
68+
}
69+
if stsRole != stsEntry.StsRole {
70+
t.Fatalf("Expected STS role %v, got: %v", stsEntry.StsRole, stsRole)
71+
}
72+
}

auth/aws/path_config_rotate_root.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
177177

178178
// Previous cached clients need to be cleared because they may have been made using
179179
// the soon-to-be-obsolete credentials.
180-
b.IAMClientsMap = make(map[string]map[string]*iam.IAM)
181-
b.EC2ClientsMap = make(map[string]map[string]*ec2.EC2)
180+
b.IAMClientsMap = make(map[string]map[string]map[string]*iam.IAM)
181+
b.EC2ClientsMap = make(map[string]map[string]map[string]*ec2.EC2)
182182

183183
// Now to clean up the old key.
184184
deleteAccessKeyInput := iam.DeleteAccessKeyInput{

0 commit comments

Comments
 (0)