Skip to content

Commit ae01193

Browse files
committed
fix: update instance object rather than replace
1 parent 3815230 commit ae01193

File tree

4 files changed

+52
-36
lines changed

4 files changed

+52
-36
lines changed

dialer.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
186186

187187
var endInfo trace.EndSpanFunc
188188
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
189-
i, err := d.instance(instance, cfg.refreshCfg)
189+
i, err := d.instance(instance, &cfg.refreshCfg)
190190
if err != nil {
191191
endInfo(err)
192192
return nil, err
@@ -240,7 +240,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
240240
// corespond to one of the following types for the instance:
241241
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
242242
func (d *Dialer) EngineVersion(ctx context.Context, instance string) (string, error) {
243-
i, err := d.instance(instance, d.defaultDialCfg.refreshCfg)
243+
i, err := d.instance(instance, nil)
244244
if err != nil {
245245
return "", err
246246
}
@@ -258,7 +258,7 @@ func (d *Dialer) Warmup(ctx context.Context, instance string, opts ...DialOption
258258
for _, opt := range opts {
259259
opt(&cfg)
260260
}
261-
_, err := d.instance(instance, d.defaultDialCfg.refreshCfg)
261+
_, err := d.instance(instance, &d.defaultDialCfg.refreshCfg)
262262
return err
263263
}
264264

@@ -301,30 +301,33 @@ func (d *Dialer) Close() error {
301301
return nil
302302
}
303303

304-
func (d *Dialer) instance(connName string, rCfg cloudsql.RefreshCfg) (*cloudsql.Instance, error) {
304+
// instance is a helper function for returning the appropriate instance object in a threadsafe way.
305+
// It will create a new instance object, modify the existing one, or leave it unchanged as needed.
306+
func (d *Dialer) instance(connName string, rCfg *cloudsql.RefreshCfg) (*cloudsql.Instance, error) {
305307
// Check instance cache
306308
d.lock.RLock()
307309
i, ok := d.instances[connName]
308310
d.lock.RUnlock()
309311
// Check if the instance exists and that the refresh cfg is the same
310-
if !ok || rCfg != i.RefreshCfg {
312+
if !ok || (rCfg != nil && *rCfg != i.RefreshCfg) {
311313
d.lock.Lock()
312314
// Recheck to ensure instance wasn't created or changed between locks
313315
i, ok = d.instances[connName]
314-
if !ok || rCfg != i.RefreshCfg {
316+
if !ok {
315317
// Create a new instance
318+
if rCfg == nil {
319+
rCfg = &d.defaultDialCfg.refreshCfg
320+
}
316321
var err error
317-
newI, err := cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource, d.dialerID, rCfg)
322+
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource, d.dialerID, *rCfg)
318323
if err != nil {
319324
d.lock.Unlock()
320325
return nil, err
321326
}
322-
// If we created a new instance to match RefreshCfg, close the old one
323-
if ok {
324-
defer i.Close()
325-
}
326-
d.instances[connName] = newI
327-
i = newI
327+
d.instances[connName] = i
328+
} else if rCfg != nil && *rCfg != i.RefreshCfg {
329+
// Update the instance with the new refresh cfg
330+
i.UpdateRefresh(*rCfg)
328331
}
329332
d.lock.Unlock()
330333
}

internal/cloudsql/instance.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ type Instance struct {
128128
OpenConns uint64
129129

130130
connName
131-
key *rsa.PrivateKey
132-
r refresher
133-
RefreshCfg RefreshCfg
131+
key *rsa.PrivateKey
134132

135133
resultGuard sync.RWMutex
134+
r refresher
135+
RefreshCfg RefreshCfg
136136
// cur represents the current refreshOperation that will be used to create connections. If a valid complete
137137
// refreshOperation isn't available it's possible for cur to be equal to next.
138138
cur *refreshOperation
@@ -161,10 +161,6 @@ func NewInstance(
161161
return nil, err
162162
}
163163
ctx, cancel := context.WithCancel(context.Background())
164-
var iamTs oauth2.TokenSource
165-
if rCfg.UseIAMAuthN {
166-
iamTs = ts
167-
}
168164
i := &Instance{
169165
connName: cn,
170166
key: key,
@@ -173,7 +169,7 @@ func NewInstance(
173169
30*time.Second,
174170
2,
175171
client,
176-
iamTs,
172+
ts,
177173
dialerID,
178174
),
179175
ctx: ctx,
@@ -224,6 +220,19 @@ func (i *Instance) InstanceEngineVersion(ctx context.Context) (string, error) {
224220
return res.md.version, nil
225221
}
226222

223+
func (i *Instance) UpdateRefresh(cfg RefreshCfg) {
224+
i.resultGuard.Lock()
225+
// Cancel any pending refreshes
226+
i.cur.Cancel()
227+
i.next.Cancel()
228+
// update the refreshcfg as needed
229+
i.RefreshCfg = cfg
230+
// reschedule a new refresh immiediately
231+
i.cur = i.scheduleRefresh(0)
232+
i.next = i.cur
233+
i.resultGuard.Unlock()
234+
}
235+
227236
// ForceRefresh triggers an immediate refresh operation to be scheduled and used for future connection attempts.
228237
func (i *Instance) ForceRefresh() {
229238
i.resultGuard.Lock()
@@ -254,7 +263,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
254263
res := &refreshOperation{}
255264
res.ready = make(chan struct{})
256265
res.timer = time.AfterFunc(d, func() {
257-
res.md, res.tlsCfg, res.expiry, res.err = i.r.performRefresh(i.ctx, i.connName, i.key)
266+
res.md, res.tlsCfg, res.expiry, res.err = i.r.performRefresh(i.ctx, i.connName, i.key, i.RefreshCfg.UseIAMAuthN)
258267
close(res.ready)
259268

260269
// Once the refresh is complete, update "current" with working result and schedule a new refresh

internal/cloudsql/refresh.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func newRefresher(
270270
dialerID: dialerID,
271271
clientLimiter: rate.NewLimiter(rate.Every(interval), burst),
272272
client: svc,
273-
iamTS: ts,
273+
ts: ts,
274274
}
275275
}
276276

@@ -286,12 +286,12 @@ type refresher struct {
286286
clientLimiter *rate.Limiter
287287
client *sqladmin.Service
288288

289-
// iamTS is the TokenSource used for IAM DB AuthN. It is only set if IAM DB AuthN is being used.
290-
iamTS oauth2.TokenSource
289+
// ts is the TokenSource used for IAM DB AuthN.
290+
ts oauth2.TokenSource
291291
}
292292

293293
// performRefresh immediately performs a full refresh operation using the Cloud SQL Admin API.
294-
func (r refresher) performRefresh(ctx context.Context, cn connName, k *rsa.PrivateKey) (md metadata, c *tls.Config, expiry time.Time, err error) {
294+
func (r refresher) performRefresh(ctx context.Context, cn connName, k *rsa.PrivateKey, iamAuthN bool) (md metadata, c *tls.Config, expiry time.Time, err error) {
295295
var refreshEnd trace.EndSpanFunc
296296
ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.RefreshConnection",
297297
trace.AddInstanceName(cn.String()),
@@ -337,7 +337,11 @@ func (r refresher) performRefresh(ctx context.Context, cn connName, k *rsa.Priva
337337
ecC := make(chan ecRes, 1)
338338
go func() {
339339
defer close(ecC)
340-
ec, err := fetchEphemeralCert(ctx, r.client, cn, k, r.iamTS)
340+
var iamTS oauth2.TokenSource
341+
if iamAuthN {
342+
iamTS = r.ts
343+
}
344+
ec, err := fetchEphemeralCert(ctx, r.client, cn, k, iamTS)
341345
ecC <- ecRes{ec, err}
342346
}()
343347

internal/cloudsql/refresh_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func TestRefresh(t *testing.T) {
6060
}()
6161

6262
r := newRefresher(time.Hour, 30*time.Second, 2, client, nil, "")
63-
md, tlsCfg, gotExpiry, err := r.performRefresh(context.Background(), cn, RSAKey)
63+
md, tlsCfg, gotExpiry, err := r.performRefresh(context.Background(), cn, RSAKey, false)
6464
if err != nil {
6565
t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
6666
}
@@ -101,22 +101,22 @@ func TestRefreshFailsFast(t *testing.T) {
101101
defer cleanup()
102102

103103
r := newRefresher(time.Hour, 30*time.Second, 1, client, nil, "")
104-
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey)
104+
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey, false)
105105
if err != nil {
106106
t.Fatalf("expected no error, got = %v", err)
107107
}
108108

109109
ctx, cancel := context.WithCancel(context.Background())
110110
cancel()
111111
// context is canceled
112-
_, _, _, err = r.performRefresh(ctx, cn, RSAKey)
112+
_, _, _, err = r.performRefresh(ctx, cn, RSAKey, false)
113113
if !errors.Is(err, context.Canceled) {
114114
t.Fatalf("expected context.Canceled error, got = %v", err)
115115
}
116116

117117
// force the rate limiter to throttle with a timed out context
118118
ctx, _ = context.WithTimeout(context.Background(), time.Millisecond)
119-
_, _, _, err = r.performRefresh(ctx, cn, RSAKey)
119+
_, _, _, err = r.performRefresh(ctx, cn, RSAKey, false)
120120

121121
var wantErr *errtype.DialError
122122
if !errors.As(err, &wantErr) {
@@ -186,7 +186,7 @@ func TestRefreshAdjustsCertExpiry(t *testing.T) {
186186
t.Run(tc.desc, func(t *testing.T) {
187187
ts := &fakeTokenSource{responses: tc.resps}
188188
r := newRefresher(time.Hour, 30*time.Second, 1, client, ts, "")
189-
_, _, gotExpiry, err := r.performRefresh(context.Background(), cn, RSAKey)
189+
_, _, gotExpiry, err := r.performRefresh(context.Background(), cn, RSAKey, true)
190190
if err != nil {
191191
t.Fatalf("want no error, got = %v", err)
192192
}
@@ -232,7 +232,7 @@ func TestRefreshWithIAMAuthErrors(t *testing.T) {
232232
t.Run(tc.desc, func(t *testing.T) {
233233
ts := &fakeTokenSource{responses: tc.resps}
234234
r := newRefresher(time.Hour, 30*time.Second, 1, client, ts, "")
235-
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey)
235+
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey, true)
236236
if err == nil {
237237
t.Fatalf("expected get failed error, got = %v", err)
238238
}
@@ -326,7 +326,7 @@ func TestRefreshWithFailedMetadataCall(t *testing.T) {
326326
defer cleanup()
327327

328328
r := newRefresher(time.Hour, 30*time.Second, 1, client, nil, "")
329-
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey)
329+
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey, false)
330330

331331
if !errors.As(err, &tc.wantErr) {
332332
t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
@@ -392,7 +392,7 @@ func TestRefreshWithFailedEphemeralCertCall(t *testing.T) {
392392
defer cleanup()
393393

394394
r := newRefresher(time.Hour, 30*time.Second, 1, client, nil, "")
395-
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey)
395+
_, _, _, err = r.performRefresh(context.Background(), cn, RSAKey, false)
396396

397397
if !errors.As(err, &tc.wantErr) {
398398
t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
@@ -419,7 +419,7 @@ func TestRefreshBuildsTLSConfig(t *testing.T) {
419419
defer cleanup()
420420

421421
r := newRefresher(time.Hour, 30*time.Second, 1, client, nil, "")
422-
_, tlsCfg, _, err := r.performRefresh(context.Background(), cn, RSAKey)
422+
_, tlsCfg, _, err := r.performRefresh(context.Background(), cn, RSAKey, false)
423423
if err != nil {
424424
t.Fatalf("expected no error, got = %v", err)
425425
}

0 commit comments

Comments
 (0)