diff --git a/dialer.go b/dialer.go index 5c6fe7d9..d43d79ad 100644 --- a/dialer.go +++ b/dialer.go @@ -338,7 +338,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn trace.AddDialerID(d.dialerID), ) defer func() { - go trace.RecordDialError(context.Background(), icn, d.dialerID, err) + trace.RecordDialError(context.Background(), icn, d.dialerID, err) endDial(err) }() cn, err := d.resolver.Resolve(ctx, icn) @@ -429,14 +429,12 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn } latency := time.Since(startTime).Milliseconds() - go func() { - n := atomic.AddUint64(c.openConnsCount, 1) - trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String()) - trace.RecordDialLatency(ctx, icn, d.dialerID, latency) - }() + n := c.openConnsCount.Add(1) + trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String()) + trace.RecordDialLatency(ctx, icn, d.dialerID, latency) closeFunc := func() { - n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1 + n := c.openConnsCount.Add(^uint64(0)) // c.openConnsCount = c.openConnsCount - 1 trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String()) } errFunc := func(err error) { @@ -571,25 +569,36 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err // newInstrumentedConn initializes an instrumentedConn that on closing will // decrement the number of open connects and record the result. func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn { - return &instrumentedConn{ - Conn: conn, - closeFunc: closeFunc, - errFunc: errFunc, - dialerID: dialerID, - connName: connName, + ctx, cancel := context.WithCancel(context.Background()) + c := &instrumentedConn{ + Conn: conn, + closeFunc: closeFunc, + errFunc: errFunc, + dialerID: dialerID, + connName: connName, + reportTicker: time.NewTicker(5 * time.Second), + stopReporter: cancel, } + + go c.report(ctx) + + return c } // instrumentedConn wraps a net.Conn and invokes closeFunc when the connection // is closed. type instrumentedConn struct { net.Conn - closeFunc func() - errFunc func(error) - mu sync.RWMutex - closed bool - dialerID string - connName string + closeFunc func() + errFunc func(error) + mu sync.RWMutex + closed bool + dialerID string + connName string + bytesRead atomic.Int64 + bytesWritten atomic.Int64 + reportTicker *time.Ticker + stopReporter func() } // Read delegates to the underlying net.Conn interface and records number of @@ -597,7 +606,7 @@ type instrumentedConn struct { func (i *instrumentedConn) Read(b []byte) (int, error) { bytesRead, err := i.Conn.Read(b) if err == nil { - go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID) + i.bytesRead.Add(int64(bytesRead)) } else { i.errFunc(err) } @@ -609,7 +618,7 @@ func (i *instrumentedConn) Read(b []byte) (int, error) { func (i *instrumentedConn) Write(b []byte) (int, error) { bytesWritten, err := i.Conn.Write(b) if err == nil { - go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID) + i.bytesWritten.Add(int64(bytesWritten)) } else { i.errFunc(err) } @@ -629,12 +638,29 @@ func (i *instrumentedConn) Close() error { i.mu.Lock() defer i.mu.Unlock() i.closed = true - err := i.Conn.Close() - if err != nil { - return err + i.stopReporter() + i.reportCounters() + i.closeFunc() + return i.Conn.Close() +} + +func (i *instrumentedConn) reportCounters() { + bytesRead := i.bytesRead.Swap(0) + bytesWritten := i.bytesWritten.Swap(0) + trace.RecordBytesReceived(context.Background(), bytesRead, i.connName, i.dialerID) + trace.RecordBytesSent(context.Background(), bytesWritten, i.connName, i.dialerID) +} + +func (i *instrumentedConn) report(ctx context.Context) { + defer i.reportTicker.Stop() + for { + select { + case <-i.reportTicker.C: + i.reportCounters() + case <-ctx.Done(): + return + } } - go i.closeFunc() - return nil } // Close closes the Dialer; it prevents the Dialer from refreshing the information diff --git a/dialer_test.go b/dialer_test.go index 8f9073ad..86b0ad97 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -1024,13 +1024,13 @@ func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { } type changingResolver struct { - stage *int32 + stage atomic.Int32 } func (r *changingResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { // For TestDialerFailoverOnInstanceChange if name == "update.example.com" { - if atomic.LoadInt32(r.stage) == 0 { + if r.stage.Load() == 0 { return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") } return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com") @@ -1054,9 +1054,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) { "my-project", "my-region", "my-instance2", mock.WithDNS("update.example.com"), ) - r := &changingResolver{ - stage: new(int32), - } + r := &changingResolver{} d := setupDialer(t, setupConfig{ skipServer: true, @@ -1084,7 +1082,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) { "update.example.com", ) stop1() - atomic.StoreInt32(r.stage, 1) + r.stage.Store(1) time.Sleep(1 * time.Second) instCn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") diff --git a/metrics_test.go b/metrics_test.go index aeef82a0..ed24bb36 100644 --- a/metrics_test.go +++ b/metrics_test.go @@ -172,7 +172,6 @@ func TestDialerWithMetrics(t *testing.T) { if err != nil { t.Fatalf("expected Dial to succeed, but got error: %v", err) } - defer conn.Close() // dial the good instance again to check the counter conn2, err := d.Dial(context.Background(), "my-project:my-region:my-instance") if err != nil { @@ -194,7 +193,6 @@ func TestDialerWithMetrics(t *testing.T) { if err != nil { t.Fatalf("conn.Write failed: %v", err) } - defer conn2.Close() // dial a bogus instance _, err = d.Dial(context.Background(), "my-project:my-region:notaninstance") if err == nil { @@ -205,6 +203,12 @@ func TestDialerWithMetrics(t *testing.T) { // success metrics wantLastValueMetric(t, "cloudsqlconn/open_connections", spy.data(), 2) + + conn.Close() + conn2.Close() + + time.Sleep(10 * time.Millisecond) // allow exporter a chance to run + wantDistributionMetric(t, "cloudsqlconn/dial_latency", spy.data()) wantCountMetric(t, "cloudsqlconn/refresh_success_count", spy.data()) wantSumMetric(t, "cloudsqlconn/bytes_sent", spy.data()) diff --git a/monitored_cache.go b/monitored_cache.go index b3929b53..41f4b5ed 100644 --- a/monitored_cache.go +++ b/monitored_cache.go @@ -27,7 +27,7 @@ import ( // monitoredCache is a wrapper around a connectionInfoCache that tracks the // number of connections to the associated instance. type monitoredCache struct { - openConnsCount *uint64 + openConnsCount atomic.Uint64 cn instance.ConnName resolver instance.ConnectionNameResolver logger debug.ContextLogger @@ -53,7 +53,6 @@ func newMonitoredCache( logger debug.ContextLogger) *monitoredCache { c := &monitoredCache{ - openConnsCount: new(uint64), closedCh: make(chan struct{}), cn: cn, resolver: resolver, @@ -98,13 +97,13 @@ func (c *monitoredCache) Close() error { c.domainNameTicker.Stop() } - if atomic.LoadUint64(c.openConnsCount) > 0 { + if c.openConnsCount.Load() > 0 { for _, socket := range c.openConns { if !socket.isClosed() { _ = socket.Close() // force socket closed, ok to ignore error. } } - atomic.StoreUint64(c.openConnsCount, 0) + c.openConnsCount.Store(0) } return c.connectionInfoCache.Close() diff --git a/monitored_cache_test.go b/monitored_cache_test.go index 0fa42a58..662aeb84 100644 --- a/monitored_cache_test.go +++ b/monitored_cache_test.go @@ -63,9 +63,7 @@ func TestMonitoredCache_purgeClosedConns(t *testing.T) { func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) { cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") - r := &changingResolver{ - stage: new(int32), - } + r := &changingResolver{} c := newMonitoredCache(context.TODO(), &spyConnectionInfoCache{}, cn, @@ -81,7 +79,7 @@ func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) { t.Fatal("got cache closed, want cache open") } // update the domain name - atomic.StoreInt32(r.stage, 1) + r.stage.Store(1) // wait for the resolver to run time.Sleep(100 * time.Millisecond) @@ -93,11 +91,9 @@ func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) { func TestMonitoredCache_Close(t *testing.T) { cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") - var closeFuncCalls int32 + var closeFuncCalls atomic.Int32 - r := &changingResolver{ - stage: new(int32), - } + r := &changingResolver{} c := newMonitoredCache(context.TODO(), &spyConnectionInfoCache{}, @@ -107,29 +103,32 @@ func TestMonitoredCache_Close(t *testing.T) { &testLog{t: t}, ) inc := func() { - atomic.AddInt32(&closeFuncCalls, 1) + closeFuncCalls.Add(1) } c.mu.Lock() // set up the state as if there were 2 open connections. c.openConns = []*instrumentedConn{ { - closed: false, - closeFunc: inc, - Conn: &mockConn{}, + closed: false, + closeFunc: inc, + stopReporter: func() {}, + Conn: &mockConn{}, }, { - closed: false, - closeFunc: inc, - Conn: &mockConn{}, + closed: false, + closeFunc: inc, + stopReporter: func() {}, + Conn: &mockConn{}, }, { - closed: true, - closeFunc: inc, - Conn: &mockConn{}, + closed: true, + closeFunc: inc, + stopReporter: func() {}, + Conn: &mockConn{}, }, } - *c.openConnsCount = 2 + c.openConnsCount.Store(2) c.mu.Unlock() c.Close() @@ -138,7 +137,7 @@ func TestMonitoredCache_Close(t *testing.T) { } // wait for closeFunc() to be called. time.Sleep(100 * time.Millisecond) - if got := atomic.LoadInt32(&closeFuncCalls); got != 2 { + if got := closeFuncCalls.Load(); got != 2 { t.Fatalf("got %d, want 2", got) }