Skip to content

Commit a65d913

Browse files
Christopher Swensonjasonodonnell
andauthored
database: Avoid race condition in connection creation (#26147)
When creating database connections, there is a race condition when multiple goroutines try to create the connection at the same time. This happens, for example, on leadership changes in a cluster. Normally, the extra database connections are cleaned up when this is detected. However, some database implementations, notably Postgres, do not seem to clean up in a timely manner, and can leak connections in these scenarios. To fix this, we create a global lock when creating database connections to prevent multiple connections from being created at the same time. We also clean up the logic at the end so that if (somehow) we ended up creating an additional connection, we use the existing one rather than the new one. This by itself would solve our problem long-term; however, it would still involve many transient database connections being created and immediately killed on leadership changes. It's not ideal to have a single global lock for database connection creation. Some potential alternatives: * a map of locks from the connection name to the lock. The biggest downside is that we will probably want to garbage collect this map so that we don't have an unbounded number of locks. * a small pool of locks, where we hash the connection names to pick the lock. Using such a pool is generally a good way to introduce deadlock, but since we will only use it in a specific case, and the purpose is to improve performance for concurrent connection creation, this is probably acceptable. Co-authored-by: Jason O'Donnell <[email protected]>
1 parent 7bd75eb commit a65d913

File tree

6 files changed

+158
-12
lines changed

6 files changed

+158
-12
lines changed

builtin/logical/database/backend.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,9 @@ func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]m
161161

162162
type databaseBackend struct {
163163
// connections holds configured database connections by config name
164-
connections *syncmap.SyncMap[string, *dbPluginInstance]
165-
logger log.Logger
164+
createConnectionLock sync.Mutex
165+
connections *syncmap.SyncMap[string, *dbPluginInstance]
166+
logger log.Logger
166167

167168
*framework.Backend
168169
// credRotationQueue is an in-memory priority queue used to track Static Roles
@@ -291,11 +292,23 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage,
291292
}
292293

293294
func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) {
295+
// fast path, reuse the existing connection
294296
dbi := b.connections.Get(name)
295297
if dbi != nil {
296298
return dbi, nil
297299
}
298300

301+
// slow path, create a new connection
302+
// if we don't lock the rest of the operation, there is a race condition for multiple callers of this function
303+
b.createConnectionLock.Lock()
304+
defer b.createConnectionLock.Unlock()
305+
306+
// check again in case we lost the race
307+
dbi = b.connections.Get(name)
308+
if dbi != nil {
309+
return dbi, nil
310+
}
311+
299312
id, err := uuid.GenerateUUID()
300313
if err != nil {
301314
return nil, err
@@ -332,14 +345,17 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
332345
name: name,
333346
runningPluginVersion: pluginVersion,
334347
}
335-
oldConn := b.connections.Put(name, dbi)
336-
if oldConn != nil {
337-
err := oldConn.Close()
348+
conn, ok := b.connections.PutIfEmpty(name, dbi)
349+
if !ok {
350+
// this is a bug
351+
b.Logger().Warn("BUG: there was a race condition adding to the database connection map")
352+
// There was already an existing connection, so we will use that and close our new one to avoid a race condition.
353+
err := dbi.Close()
338354
if err != nil {
339-
b.Logger().Warn("Error closing database connection", "error", err)
355+
b.Logger().Warn("Error closing new database connection", "error", err)
340356
}
341357
}
342-
return dbi, nil
358+
return conn, nil
343359
}
344360

345361
// ClearConnection closes the database connection and
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: BUSL-1.1
3+
4+
package database
5+
6+
import (
7+
"context"
8+
"sync"
9+
"testing"
10+
11+
"github.com/hashicorp/vault/sdk/helper/consts"
12+
"github.com/hashicorp/vault/sdk/helper/pluginutil"
13+
"github.com/hashicorp/vault/sdk/logical"
14+
"github.com/hashicorp/vault/sdk/queue"
15+
)
16+
17+
func newSystemViewWrapper(view logical.SystemView) logical.SystemView {
18+
return &systemViewWrapper{
19+
view,
20+
}
21+
}
22+
23+
type systemViewWrapper struct {
24+
logical.SystemView
25+
}
26+
27+
var _ logical.ExtendedSystemView = (*systemViewWrapper)(nil)
28+
29+
func (s *systemViewWrapper) RequestWellKnownRedirect(ctx context.Context, src, dest string) error {
30+
panic("nope")
31+
}
32+
33+
func (s *systemViewWrapper) DeregisterWellKnownRedirect(ctx context.Context, src string) bool {
34+
panic("nope")
35+
}
36+
37+
func (s *systemViewWrapper) Auditor() logical.Auditor {
38+
panic("nope")
39+
}
40+
41+
func (s *systemViewWrapper) ForwardGenericRequest(ctx context.Context, request *logical.Request) (*logical.Response, error) {
42+
panic("nope")
43+
}
44+
45+
func (s *systemViewWrapper) APILockShouldBlockRequest() (bool, error) {
46+
panic("nope")
47+
}
48+
49+
func (s *systemViewWrapper) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
50+
return nil, pluginutil.ErrPinnedVersionNotFound
51+
}
52+
53+
func (s *systemViewWrapper) LookupPluginVersion(ctx context.Context, pluginName string, pluginType consts.PluginType, version string) (*pluginutil.PluginRunner, error) {
54+
return &pluginutil.PluginRunner{
55+
Name: mockv5,
56+
Type: consts.PluginTypeDatabase,
57+
Builtin: true,
58+
BuiltinFactory: New,
59+
}, nil
60+
}
61+
62+
func getDbBackend(t *testing.T) (*databaseBackend, logical.Storage) {
63+
t.Helper()
64+
config := logical.TestBackendConfig()
65+
config.System = newSystemViewWrapper(config.System)
66+
config.StorageView = &logical.InmemStorage{}
67+
// Create and init the backend ourselves instead of using a Factory because
68+
// the factory function kicks off threads that cause racy tests.
69+
b := Backend(config)
70+
if err := b.Setup(context.Background(), config); err != nil {
71+
t.Fatal(err)
72+
}
73+
b.schedule = &TestSchedule{}
74+
b.credRotationQueue = queue.New()
75+
b.populateQueue(context.Background(), config.StorageView)
76+
77+
return b, config.StorageView
78+
}
79+
80+
// TestGetConnectionRaceCondition checks that GetConnection always returns the same instance, even when asked
81+
// by multiple goroutines in parallel.
82+
func TestGetConnectionRaceCondition(t *testing.T) {
83+
ctx := context.Background()
84+
b, s := getDbBackend(t)
85+
defer b.Cleanup(ctx)
86+
configureDBMount(t, s)
87+
88+
goroutines := 16
89+
90+
wg := sync.WaitGroup{}
91+
wg.Add(goroutines)
92+
dbis := make([]*dbPluginInstance, goroutines)
93+
errs := make([]error, goroutines)
94+
for i := 0; i < goroutines; i++ {
95+
go func(i int) {
96+
defer wg.Done()
97+
dbis[i], errs[i] = b.GetConnection(ctx, s, mockv5)
98+
}(i)
99+
}
100+
wg.Wait()
101+
for i := 0; i < goroutines; i++ {
102+
if errs[i] != nil {
103+
t.Fatal(errs[i])
104+
}
105+
if dbis[0] != dbis[i] {
106+
t.Fatal("Error: database instances did not match")
107+
}
108+
}
109+
}

builtin/logical/database/mockv5.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ func (m MockDatabaseV5) Initialize(ctx context.Context, req v5.InitializeRequest
5151
"req", req)
5252

5353
config := req.Config
54+
if config == nil {
55+
config = map[string]interface{}{}
56+
}
5457
config["from-plugin"] = "this value is from the plugin itself"
5558

5659
resp := v5.InitializeResponse{

builtin/logical/database/rotation_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
)
3737

3838
const (
39+
mockv5 = "mockv5"
3940
dbUser = "vaultstatictest"
4041
dbUserDefaultPassword = "password"
4142
testMinRotationWindowSeconds = 5
@@ -1446,7 +1447,7 @@ func TestStoredWALsCorrectlyProcessed(t *testing.T) {
14461447

14471448
rotationPeriodData := map[string]interface{}{
14481449
"username": "hashicorp",
1449-
"db_name": "mockv5",
1450+
"db_name": mockv5,
14501451
"rotation_period": "86400s",
14511452
}
14521453

@@ -1500,7 +1501,7 @@ func TestStoredWALsCorrectlyProcessed(t *testing.T) {
15001501
},
15011502
map[string]interface{}{
15021503
"username": "hashicorp",
1503-
"db_name": "mockv5",
1504+
"db_name": mockv5,
15041505
"rotation_schedule": "*/10 * * * * *",
15051506
},
15061507
},
@@ -1699,9 +1700,9 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
16991700
dbi := &dbPluginInstance{
17001701
database: dbw,
17011702
id: "foo-id",
1702-
name: "mockV5",
1703+
name: mockv5,
17031704
}
1704-
b.connections.Put("mockv5", dbi)
1705+
b.connections.Put(mockv5, dbi)
17051706

17061707
return mockDB
17071708
}
@@ -1710,7 +1711,7 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
17101711
// plugin init code paths, allowing us to use a manually populated mock DB object.
17111712
func configureDBMount(t *testing.T, storage logical.Storage) {
17121713
t.Helper()
1713-
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/mockv5"), &DatabaseConfig{
1714+
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/"+mockv5), &DatabaseConfig{
17141715
AllowedRoles: []string{"*"},
17151716
})
17161717
if err != nil {

changelog/26147.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
```release-note:bug
2+
secret/database: Fixed race condition where database mounts may leak connections
3+
```

helper/syncmap/syncmap.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,20 @@ func (m *SyncMap[K, V]) Put(k K, v V) V {
6262
return oldV
6363
}
6464

65+
// PutIfEmpty adds the given key-value pair to the map only if there is no value already in it,
66+
// and returns the new value and true if so.
67+
// If there is already a value, it returns the existing value and false.
68+
func (m *SyncMap[K, V]) PutIfEmpty(k K, v V) (V, bool) {
69+
m.lock.Lock()
70+
defer m.lock.Unlock()
71+
oldV, ok := m.data[k]
72+
if ok {
73+
return oldV, false
74+
}
75+
m.data[k] = v
76+
return v, true
77+
}
78+
6579
// Clear deletes all entries from the map, and returns the previous map.
6680
func (m *SyncMap[K, V]) Clear() map[K]V {
6781
m.lock.Lock()

0 commit comments

Comments
 (0)