Skip to content

Commit 533eb4e

Browse files
authored
feat: use handshake context when possible (#199)
This is a port of GoogleCloudPlatform/cloud-sql-go-connector#427.
1 parent c519ba8 commit 533eb4e

File tree

3 files changed

+89
-7
lines changed

3 files changed

+89
-7
lines changed

connect_tls_117.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build go1.17
16+
// +build go1.17
17+
18+
package alloydbconn
19+
20+
import (
21+
"context"
22+
"crypto/tls"
23+
"net"
24+
25+
"cloud.google.com/go/alloydbconn/errtype"
26+
"cloud.google.com/go/alloydbconn/internal/alloydb"
27+
)
28+
29+
// connectTLS returns a new TLS client side connection
30+
// using conn as the underlying transport.
31+
//
32+
// The returned connection has already completed its TLS handshake.
33+
func connectTLS(ctx context.Context, conn net.Conn, c *tls.Config, i *alloydb.Instance) (net.Conn, error) {
34+
tlsConn := tls.Client(conn, c)
35+
// HandshakeContext was introduced in Go 1.17, hence
36+
// this file is conditionally compiled on only Go versions >= 1.17.
37+
if err := tlsConn.HandshakeContext(ctx); err != nil {
38+
// refresh the instance info in case it caused the handshake failure
39+
i.ForceRefresh()
40+
_ = tlsConn.Close() // best effort close attempt
41+
return nil, errtype.NewDialError("handshake failed", i.String(), err)
42+
}
43+
return tlsConn, nil
44+
}

connect_tls_other.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build !go1.17
16+
// +build !go1.17
17+
18+
package alloydbconn
19+
20+
import (
21+
"context"
22+
"crypto/tls"
23+
"net"
24+
25+
"cloud.google.com/go/alloydbconn/errtype"
26+
"cloud.google.com/go/alloydbconn/internal/alloydb"
27+
)
28+
29+
// connectTLS returns a new TLS client side connection
30+
// using conn as the underlying transport.
31+
//
32+
// The returned connection has already completed its TLS handshake.
33+
func connectTLS(_ context.Context, conn net.Conn, c *tls.Config, i *alloydb.Instance) (net.Conn, error) {
34+
tlsConn := tls.Client(conn, c)
35+
if err := tlsConn.Handshake(); err != nil {
36+
// refresh the instance info in case it caused the handshake failure
37+
i.ForceRefresh()
38+
_ = tlsConn.Close() // best effort close attempt
39+
return nil, errtype.NewDialError("handshake failed", i.String(), err)
40+
}
41+
return tlsConn, nil
42+
}

dialer.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"context"
1919
"crypto/rand"
2020
"crypto/rsa"
21-
"crypto/tls"
2221
_ "embed"
2322
"fmt"
2423
"net"
@@ -194,12 +193,9 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
194193
return nil, errtype.NewDialError("failed to set keep-alive period", i.String(), err)
195194
}
196195
}
197-
tlsConn := tls.Client(conn, tlsCfg)
198-
if err := tlsConn.Handshake(); err != nil {
199-
// refresh the instance info in case it caused the handshake failure
200-
i.ForceRefresh()
201-
_ = tlsConn.Close() // best effort close attempt
202-
return nil, errtype.NewDialError("handshake failed", i.String(), err)
196+
tlsConn, err := connectTLS(ctx, conn, tlsCfg, i)
197+
if err != nil {
198+
return nil, err
203199
}
204200
latency := time.Since(startTime).Milliseconds()
205201
go func() {

0 commit comments

Comments
 (0)