Skip to content

Commit 37c4e70

Browse files
authored
fix: use handshake context when possible (#427)
Fixes #409.
1 parent 52b7633 commit 37c4e70

File tree

3 files changed

+92
-10
lines changed

3 files changed

+92
-10
lines changed

connect_tls_117.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build go1.17
16+
// +build go1.17
17+
18+
package cloudsqlconn
19+
20+
import (
21+
"context"
22+
"crypto/tls"
23+
"net"
24+
25+
"cloud.google.com/go/cloudsqlconn/errtype"
26+
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
27+
)
28+
29+
// connectTLS returns a new TLS client side connection
30+
// using conn as the underlying transport.
31+
//
32+
// The returned connection has already completed its TLS handshake.
33+
func connectTLS(ctx context.Context, conn net.Conn, c *tls.Config, i *cloudsql.Instance) (net.Conn, error) {
34+
tlsConn := tls.Client(conn, c)
35+
// HandshakeContext was introduced in Go 1.17, hence
36+
// this file is conditionally compiled on only Go versions >= 1.17.
37+
if err := tlsConn.HandshakeContext(ctx); err != nil {
38+
// refresh the instance info in case it caused the handshake failure
39+
i.ForceRefresh()
40+
_ = tlsConn.Close() // best effort close attempt
41+
return nil, errtype.NewDialError("handshake failed", i.String(), err)
42+
}
43+
return tlsConn, nil
44+
}

connect_tls_other.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build !go1.17
16+
// +build !go1.17
17+
18+
package cloudsqlconn
19+
20+
import (
21+
"context"
22+
"crypto/tls"
23+
"net"
24+
25+
"cloud.google.com/go/cloudsqlconn/errtype"
26+
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
27+
)
28+
29+
// connectTLS returns a new TLS client side connection
30+
// using conn as the underlying transport.
31+
//
32+
// The returned connection has already completed its TLS handshake.
33+
func connectTLS(_ context.Context, conn net.Conn, c *tls.Config, i *cloudsql.Instance) (net.Conn, error) {
34+
tlsConn := tls.Client(conn, c)
35+
if err := tlsConn.Handshake(); err != nil {
36+
// refresh the instance info in case it caused the handshake failure
37+
i.ForceRefresh()
38+
_ = tlsConn.Close() // best effort close attempt
39+
return nil, errtype.NewDialError("handshake failed", i.String(), err)
40+
}
41+
return tlsConn, nil
42+
}

dialer.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
// Copyright 2020 Google LLC
2-
2+
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
55
// You may obtain a copy of the License at
6-
6+
//
77
// https://www.apache.org/licenses/LICENSE-2.0
8-
8+
//
99
// Unless required by applicable law or agreed to in writing, software
1010
// distributed under the License is distributed on an "AS IS" BASIS,
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -20,7 +20,6 @@ import (
2020
"context"
2121
"crypto/rand"
2222
"crypto/rsa"
23-
"crypto/tls"
2423
_ "embed"
2524
"fmt"
2625
"net"
@@ -224,12 +223,9 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
224223
return nil, errtype.NewDialError("failed to set keep-alive period", i.String(), err)
225224
}
226225
}
227-
tlsConn := tls.Client(conn, tlsCfg)
228-
if err := tlsConn.Handshake(); err != nil {
229-
// refresh the instance info in case it caused the handshake failure
230-
i.ForceRefresh()
231-
_ = tlsConn.Close() // best effort close attempt
232-
return nil, errtype.NewDialError("handshake failed", i.String(), err)
226+
tlsConn, err := connectTLS(ctx, conn, tlsCfg, i)
227+
if err != nil {
228+
return nil, err
233229
}
234230
latency := time.Since(startTime).Milliseconds()
235231
go func() {

0 commit comments

Comments
 (0)