Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
trace.AddDialerID(d.dialerID),
)
defer func() {
go trace.RecordDialError(context.Background(), icn, d.dialerID, err)
trace.RecordDialError(context.Background(), icn, d.dialerID, err)
endDial(err)
}()
cn, err := d.resolver.Resolve(ctx, icn)
Expand Down Expand Up @@ -429,14 +429,12 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
}

latency := time.Since(startTime).Milliseconds()
go func() {
n := atomic.AddUint64(c.openConnsCount, 1)
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
}()
n := c.openConnsCount.Add(1)
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)

closeFunc := func() {
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
n := c.openConnsCount.Add(^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
}
errFunc := func(err error) {
Expand Down Expand Up @@ -571,33 +569,44 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
// newInstrumentedConn initializes an instrumentedConn that on closing will
// decrement the number of open connects and record the result.
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
return &instrumentedConn{
Conn: conn,
closeFunc: closeFunc,
errFunc: errFunc,
dialerID: dialerID,
connName: connName,
ctx, cancel := context.WithCancel(context.Background())
c := &instrumentedConn{
Conn: conn,
closeFunc: closeFunc,
errFunc: errFunc,
dialerID: dialerID,
connName: connName,
reportTicker: time.NewTicker(5 * time.Second),
stopReporter: cancel,
}

go c.report(ctx)

return c
}

// instrumentedConn wraps a net.Conn and invokes closeFunc when the connection
// is closed.
type instrumentedConn struct {
net.Conn
closeFunc func()
errFunc func(error)
mu sync.RWMutex
closed bool
dialerID string
connName string
closeFunc func()
errFunc func(error)
mu sync.RWMutex
closed bool
dialerID string
connName string
bytesRead atomic.Int64
bytesWritten atomic.Int64
reportTicker *time.Ticker
stopReporter func()
}

// Read delegates to the underlying net.Conn interface and records number of
// bytes read
func (i *instrumentedConn) Read(b []byte) (int, error) {
bytesRead, err := i.Conn.Read(b)
if err == nil {
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
i.bytesRead.Add(int64(bytesRead))
} else {
i.errFunc(err)
}
Expand All @@ -609,7 +618,7 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
func (i *instrumentedConn) Write(b []byte) (int, error) {
bytesWritten, err := i.Conn.Write(b)
if err == nil {
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
i.bytesWritten.Add(int64(bytesWritten))
} else {
i.errFunc(err)
}
Expand All @@ -629,12 +638,29 @@ func (i *instrumentedConn) Close() error {
i.mu.Lock()
defer i.mu.Unlock()
i.closed = true
err := i.Conn.Close()
if err != nil {
return err
i.stopReporter()
i.reportCounters()
i.closeFunc()
return i.Conn.Close()
}

func (i *instrumentedConn) reportCounters() {
bytesRead := i.bytesRead.Swap(0)
bytesWritten := i.bytesWritten.Swap(0)
trace.RecordBytesReceived(context.Background(), bytesRead, i.connName, i.dialerID)
trace.RecordBytesSent(context.Background(), bytesWritten, i.connName, i.dialerID)
}

func (i *instrumentedConn) report(ctx context.Context) {
defer i.reportTicker.Stop()
for {
select {
case <-i.reportTicker.C:
i.reportCounters()
case <-ctx.Done():
return
}
}
go i.closeFunc()
return nil
}

// Close closes the Dialer; it prevents the Dialer from refreshing the information
Expand Down
10 changes: 4 additions & 6 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,13 +1024,13 @@ func TestDialerFailsDnsTxtRecordMissing(t *testing.T) {
}

type changingResolver struct {
stage *int32
stage atomic.Int32
}

func (r *changingResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) {
// For TestDialerFailoverOnInstanceChange
if name == "update.example.com" {
if atomic.LoadInt32(r.stage) == 0 {
if r.stage.Load() == 0 {
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
}
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com")
Expand All @@ -1054,9 +1054,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
"my-project", "my-region", "my-instance2",
mock.WithDNS("update.example.com"),
)
r := &changingResolver{
stage: new(int32),
}
r := &changingResolver{}

d := setupDialer(t, setupConfig{
skipServer: true,
Expand Down Expand Up @@ -1084,7 +1082,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
"update.example.com",
)
stop1()
atomic.StoreInt32(r.stage, 1)
r.stage.Store(1)

time.Sleep(1 * time.Second)
instCn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
Expand Down
8 changes: 6 additions & 2 deletions metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ func TestDialerWithMetrics(t *testing.T) {
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
}
defer conn.Close()
// dial the good instance again to check the counter
conn2, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
if err != nil {
Expand All @@ -194,7 +193,6 @@ func TestDialerWithMetrics(t *testing.T) {
if err != nil {
t.Fatalf("conn.Write failed: %v", err)
}
defer conn2.Close()
// dial a bogus instance
_, err = d.Dial(context.Background(), "my-project:my-region:notaninstance")
if err == nil {
Expand All @@ -205,6 +203,12 @@ func TestDialerWithMetrics(t *testing.T) {

// success metrics
wantLastValueMetric(t, "cloudsqlconn/open_connections", spy.data(), 2)

conn.Close()
conn2.Close()

time.Sleep(10 * time.Millisecond) // allow exporter a chance to run

wantDistributionMetric(t, "cloudsqlconn/dial_latency", spy.data())
wantCountMetric(t, "cloudsqlconn/refresh_success_count", spy.data())
wantSumMetric(t, "cloudsqlconn/bytes_sent", spy.data())
Expand Down
7 changes: 3 additions & 4 deletions monitored_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
// monitoredCache is a wrapper around a connectionInfoCache that tracks the
// number of connections to the associated instance.
type monitoredCache struct {
openConnsCount *uint64
openConnsCount atomic.Uint64
cn instance.ConnName
resolver instance.ConnectionNameResolver
logger debug.ContextLogger
Expand All @@ -53,7 +53,6 @@ func newMonitoredCache(
logger debug.ContextLogger) *monitoredCache {

c := &monitoredCache{
openConnsCount: new(uint64),
closedCh: make(chan struct{}),
cn: cn,
resolver: resolver,
Expand Down Expand Up @@ -98,13 +97,13 @@ func (c *monitoredCache) Close() error {
c.domainNameTicker.Stop()
}

if atomic.LoadUint64(c.openConnsCount) > 0 {
if c.openConnsCount.Load() > 0 {
for _, socket := range c.openConns {
if !socket.isClosed() {
_ = socket.Close() // force socket closed, ok to ignore error.
}
}
atomic.StoreUint64(c.openConnsCount, 0)
c.openConnsCount.Store(0)
}

return c.connectionInfoCache.Close()
Expand Down
39 changes: 19 additions & 20 deletions monitored_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ func TestMonitoredCache_purgeClosedConns(t *testing.T) {

func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) {
cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
r := &changingResolver{
stage: new(int32),
}
r := &changingResolver{}
c := newMonitoredCache(context.TODO(),
&spyConnectionInfoCache{},
cn,
Expand All @@ -81,7 +79,7 @@ func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) {
t.Fatal("got cache closed, want cache open")
}
// update the domain name
atomic.StoreInt32(r.stage, 1)
r.stage.Store(1)

// wait for the resolver to run
time.Sleep(100 * time.Millisecond)
Expand All @@ -93,11 +91,9 @@ func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) {

func TestMonitoredCache_Close(t *testing.T) {
cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
var closeFuncCalls int32
var closeFuncCalls atomic.Int32

r := &changingResolver{
stage: new(int32),
}
r := &changingResolver{}

c := newMonitoredCache(context.TODO(),
&spyConnectionInfoCache{},
Expand All @@ -107,29 +103,32 @@ func TestMonitoredCache_Close(t *testing.T) {
&testLog{t: t},
)
inc := func() {
atomic.AddInt32(&closeFuncCalls, 1)
closeFuncCalls.Add(1)
}

c.mu.Lock()
// set up the state as if there were 2 open connections.
c.openConns = []*instrumentedConn{
{
closed: false,
closeFunc: inc,
Conn: &mockConn{},
closed: false,
closeFunc: inc,
stopReporter: func() {},
Conn: &mockConn{},
},
{
closed: false,
closeFunc: inc,
Conn: &mockConn{},
closed: false,
closeFunc: inc,
stopReporter: func() {},
Conn: &mockConn{},
},
{
closed: true,
closeFunc: inc,
Conn: &mockConn{},
closed: true,
closeFunc: inc,
stopReporter: func() {},
Conn: &mockConn{},
},
}
*c.openConnsCount = 2
c.openConnsCount.Store(2)
c.mu.Unlock()

c.Close()
Expand All @@ -138,7 +137,7 @@ func TestMonitoredCache_Close(t *testing.T) {
}
// wait for closeFunc() to be called.
time.Sleep(100 * time.Millisecond)
if got := atomic.LoadInt32(&closeFuncCalls); got != 2 {
if got := closeFuncCalls.Load(); got != 2 {
t.Fatalf("got %d, want 2", got)
}

Expand Down
Loading