From 982cac350dfce3f5aa2c9f7c9808457cbc27e6aa Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 14 Jun 2022 12:09:07 -0600 Subject: [PATCH 1/5] fix: support MySQL diver's conn check. Fixes #225. --- dialer.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/dialer.go b/dialer.go index 5255725b..eef52514 100644 --- a/dialer.go +++ b/dialer.go @@ -25,6 +25,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" "cloud.google.com/go/cloudsqlconn/errtype" @@ -230,7 +231,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) trace.RecordDialLatency(ctx, instance, d.dialerID, latency) }() - return newInstrumentedConn(tlsConn, func() { + return newInstrumentedConn(conn, tlsConn, func() { n := atomic.AddUint64(&i.OpenConns, ^uint64(0)) trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, i.String()) }), nil @@ -264,9 +265,10 @@ func (d *Dialer) Warmup(ctx context.Context, instance string, opts ...DialOption // newInstrumentedConn initializes an instrumentedConn that on closing will // decrement the number of open connects and record the result. -func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn { +func newInstrumentedConn(rawConn, conn net.Conn, closeFunc func()) *instrumentedConn { return &instrumentedConn{ Conn: conn, + rawConn: rawConn, closeFunc: closeFunc, } } @@ -275,9 +277,18 @@ func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn { // is closed. type instrumentedConn struct { net.Conn + // rawConn is the underlying net.Conn without TLS + rawConn net.Conn closeFunc func() } +// SyscallConn supports a connection check in the MySQL driver by delegating to +// the underlying non-TLS net.Conn. +func (i *instrumentedConn) SyscallConn() (syscall.RawConn, error) { + sconn := i.rawConn.(syscall.Conn) + return sconn.SyscallConn() +} + // Close delegates to the underylying net.Conn interface and reports the close // to the provided closeFunc only when Close returns no error. func (i *instrumentedConn) Close() error { From 8613caa28d5a5dee10c33b06e0cd450774c1e899 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 14 Jun 2022 12:16:51 -0600 Subject: [PATCH 2/5] Don't panic on failed type assertion --- dialer.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dialer.go b/dialer.go index eef52514..d9644648 100644 --- a/dialer.go +++ b/dialer.go @@ -20,6 +20,7 @@ import ( "crypto/rsa" "crypto/tls" _ "embed" + "errors" "fmt" "net" "strings" @@ -285,7 +286,10 @@ type instrumentedConn struct { // SyscallConn supports a connection check in the MySQL driver by delegating to // the underlying non-TLS net.Conn. func (i *instrumentedConn) SyscallConn() (syscall.RawConn, error) { - sconn := i.rawConn.(syscall.Conn) + sconn, ok := i.rawConn.(syscall.Conn) + if !ok { + return nil, errors.New("connection is not a syscall.Conn") + } return sconn.SyscallConn() } From af0b40168da9fdc83feb62cffa141ed17c8311a4 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 14 Jun 2022 12:25:03 -0600 Subject: [PATCH 3/5] Use *tls.Conn.NetConn for raw connection --- dialer.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dialer.go b/dialer.go index d9644648..2d915a8a 100644 --- a/dialer.go +++ b/dialer.go @@ -286,7 +286,11 @@ type instrumentedConn struct { // SyscallConn supports a connection check in the MySQL driver by delegating to // the underlying non-TLS net.Conn. func (i *instrumentedConn) SyscallConn() (syscall.RawConn, error) { - sconn, ok := i.rawConn.(syscall.Conn) + tlsConn, ok := i.Conn.(*tls.Conn) + if !ok { + return nil, errors.New("connection is not a *tls.Conn") + } + sconn, ok := tlsConn.NetConn().(syscall.Conn) if !ok { return nil, errors.New("connection is not a syscall.Conn") } From 658e297b98288c26e080db8fd0f13bbb961a09e1 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 14 Jun 2022 12:31:03 -0600 Subject: [PATCH 4/5] tls.Conn.NetConn is Go 1.18 only --- dialer.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dialer.go b/dialer.go index 2d915a8a..d9644648 100644 --- a/dialer.go +++ b/dialer.go @@ -286,11 +286,7 @@ type instrumentedConn struct { // SyscallConn supports a connection check in the MySQL driver by delegating to // the underlying non-TLS net.Conn. func (i *instrumentedConn) SyscallConn() (syscall.RawConn, error) { - tlsConn, ok := i.Conn.(*tls.Conn) - if !ok { - return nil, errors.New("connection is not a *tls.Conn") - } - sconn, ok := tlsConn.NetConn().(syscall.Conn) + sconn, ok := i.rawConn.(syscall.Conn) if !ok { return nil, errors.New("connection is not a syscall.Conn") } From 21a7582c5f4a7d5808dc14be1702b9faeb42d379 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 14 Jun 2022 12:42:51 -0600 Subject: [PATCH 5/5] Add test to ensure compatability --- dialer_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/dialer_test.go b/dialer_test.go index 5030efa6..06eb963c 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -22,6 +22,7 @@ import ( "os" "runtime" "strings" + "syscall" "testing" "time" @@ -75,6 +76,48 @@ func TestDialerCanConnectToInstance(t *testing.T) { testSuccessfulDial(t, d, context.Background(), "my-project:my-region:my-instance", WithPublicIP()) } +func TestDialerConnectionSupportsSyscalls(t *testing.T) { + inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance") + svc, cleanup, err := mock.NewSQLAdminService( + context.Background(), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + if err != nil { + t.Fatalf("failed to init SQLAdminService: %v", err) + } + stop := mock.StartServerProxy(t, inst) + defer func() { + stop() + if err := cleanup(); err != nil { + t.Fatalf("%v", err) + } + }() + + d, err := NewDialer(context.Background(), + WithDefaultDialOptions(WithPublicIP()), + WithTokenSource(mock.EmptyTokenSource{}), + ) + if err != nil { + t.Fatalf("expected NewDialer to succeed, but got error: %v", err) + } + d.sqladmin = svc + + conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance") + if err != nil { + t.Fatalf("expected Dial to succeed, but got error: %v", err) + } + defer conn.Close() + sconn, ok := conn.(syscall.Conn) + if !ok { + t.Fatalf("expected conn to be a syscall.Conn, but it was not") + } + _, err = sconn.SyscallConn() + if err != nil { + t.Fatalf("expected syscall.RawConn, got error: %v", err) + } +} + func TestDialWithAdminAPIErrors(t *testing.T) { inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance") svc, cleanup, err := mock.NewSQLAdminService(context.Background())