diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index 6b797506..60e4e171 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -293,9 +293,17 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { res.md, res.tlsCfg, res.expiry, res.err = i.r.performRefresh(i.ctx, i.connName, i.key, i.RefreshCfg.UseIAMAuthN) close(res.ready) + select { + case <-i.ctx.Done(): + // instance has been closed, don't schedule anything + return + default: + } + // Once the refresh is complete, update "current" with working result and schedule a new refresh i.resultGuard.Lock() defer i.resultGuard.Unlock() + // if failed, scheduled the next refresh immediately if res.err != nil { i.next = i.scheduleRefresh(0) @@ -308,14 +316,9 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { } return } + // Update the current results, and schedule the next refresh in the future i.cur = res - select { - case <-i.ctx.Done(): - // instance has been closed, don't schedule anything - return - default: - } t := refreshDuration(time.Now(), i.cur.expiry) i.next = i.scheduleRefresh(t) }) diff --git a/internal/cloudsql/instance_test.go b/internal/cloudsql/instance_test.go index d86913b5..5f2d0a05 100644 --- a/internal/cloudsql/instance_test.go +++ b/internal/cloudsql/instance_test.go @@ -303,3 +303,32 @@ func TestRefreshDuration(t *testing.T) { }) } } + +func TestContextCancelled(t *testing.T) { + ctx := context.Background() + + client, cleanup, err := mock.NewSQLAdminService(ctx) + if err != nil { + t.Fatalf("%s", err) + } + defer cleanup() + + // Set up an instance and then close it immediately + im, err := NewInstance("my-proj:my-region:my-inst", client, RSAKey, 30, nil, "", RefreshCfg{}) + if err != nil { + t.Fatalf("failed to initialize Instance: %v", err) + } + im.Close() + + // grab the current value of next before scheduling another refresh + next := im.next + + op := im.scheduleRefresh(time.Nanosecond) + <-op.ready + + // if scheduleRefresh returns without scheduling another one, + // i.next should be untouched and remain the same pointer value + if im.next != next { + t.Fatalf("refresh did not return after a closed context. next pointer changed: want = %p, got = %p", next, im.next) + } +}