Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 77 additions & 30 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,8 @@ func (r *fakeResolver) Resolve(_ context.Context, name string) (instance.ConnNam
func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) {
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"),
mock.WithDNSMapping("db2.example.com", "INSTANCE", "CUSTOM_SAN"),
)
wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com")
wantName2, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db2.example.com")
Expand Down Expand Up @@ -1046,9 +1048,11 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
// SRV record and connect to the correct instance.
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNS("update.example.com"),
)
inst2 := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance2",
mock.WithDNS("update.example.com"),
)
r := &changingResolver{
stage: new(int32),
Expand Down Expand Up @@ -1104,42 +1108,85 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {

func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) {

// Create an instance with custom SAN 'db.example.com'
inst := mock.NewFakeCSQLInstanceWithSan(
"my-project", "my-region", "my-instance", []string{"db.example.com"},
mock.WithDNS("db.example.com"),
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
)
tcs := []struct {
name string
legacy bool
icn string
dn string
}{{
name: "domainName DnsName older",
legacy: true,
icn: "my-project:my-region:my-instance",
}, {
name: "domainName DnsNames newer",
legacy: false,
icn: "my-project:my-region:my-instance",
},
{
name: "InstanceConnectionName DnsName older",
legacy: true,
icn: "my-project:my-region:my-instance",
dn: "db.example.com",
}, {
name: "InstanceConnectionName DnsNames newer",
legacy: false,
icn: "my-project:my-region:my-instance",
dn: "db.example.com",
}}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// Create an instance with custom SAN 'db.example.com'
var inst mock.FakeCSQLInstance
if tc.legacy || tc.dn == "" {
inst = mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNS("db.example.com"),
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
)
} else {
inst = mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"),
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
)
}

wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com")
d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
dialerOptions: []Option{
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{
entries: map[string]instance.ConnName{
"db.example.com": wantName,
wantName, _ := instance.ParseConnNameWithDomainName(tc.icn, tc.dn)
d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
}),
},
})
dialerOptions: []Option{
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{
entries: map[string]instance.ConnName{
"db.example.com": wantName,
"my-project:my-region:my-instance": wantName,
},
}),
},
})
dnOrIcn := tc.icn
if tc.dn != "" {
dnOrIcn = tc.dn
}

// Dial db.example.com
testSuccessfulDial(
context.Background(), t, d,
"db.example.com",
)
// Dial db.example.com
testSuccessfulDial(
context.Background(), t, d,
dnOrIcn,
)
})
}
}

func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {

// Create an instance with custom SAN 'db.example.com'
inst := mock.NewFakeCSQLInstanceWithSan(
"my-project", "my-region", "my-instance", []string{"db.example.com"},
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNS("db.example.com"),
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
)
Expand Down Expand Up @@ -1207,8 +1254,8 @@ func TestDialerRefreshesAfterRotateCACerts(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
inst := mock.NewFakeCSQLInstanceWithSan(
"my-project", "my-region", "my-instance", []string{"db.example.com"},
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNS("db.example.com"),
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
)
Expand Down
21 changes: 0 additions & 21 deletions e2e_mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
package cloudsqlconn_test

import (
"context"
"database/sql"
"fmt"
"os"
"testing"
"time"

"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
gomysql "github.com/go-sql-driver/mysql"
)
Expand Down Expand Up @@ -55,16 +52,6 @@ func requireMySQLVars(t *testing.T) {
}
}

type mockResolver struct {
}

func (r *mockResolver) Resolve(_ context.Context, name string) (instanceName instance.ConnName, err error) {
if name == "mysql.example.com" {
return instance.ParseConnNameWithDomainName(mysqlConnName, "mysql.example.com")
}
return instance.ConnName{}, fmt.Errorf("no resolution for %v", name)
}

func TestMySQLDriver(t *testing.T) {
if testing.Short() {
t.Skip("skipping MySQL integration tests")
Expand Down Expand Up @@ -94,14 +81,6 @@ func TestMySQLDriver(t *testing.T) {
user: mysqlIAMUser,
password: "password",
},
{
desc: "with dns",
driverName: "cloudsql-mysql-dns",
opts: []cloudsqlconn.Option{cloudsqlconn.WithResolver(&mockResolver{})},
instanceName: "mysql.example.com",
user: mysqlUser,
password: mysqlPass,
},
}

for _, tc := range tcs {
Expand Down
6 changes: 2 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module cloud.google.com/go/cloudsqlconn

go 1.23.0

toolchain go1.23.7

require (
cloud.google.com/go/auth v0.15.0
cloud.google.com/go/auth/oauth2adapt v0.2.7
Expand All @@ -16,8 +14,8 @@ require (
golang.org/x/net v0.37.0
golang.org/x/oauth2 v0.28.0
golang.org/x/time v0.11.0
google.golang.org/api v0.224.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20250227231956-55c901821b1e
google.golang.org/api v0.225.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20250303144028-a0af3efb3deb
google.golang.org/grpc v1.71.0
)

Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,17 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.224.0 h1:Ir4UPtDsNiwIOHdExr3fAj4xZ42QjK7uQte3lORLJwU=
google.golang.org/api v0.224.0/go.mod h1:3V39my2xAGkodXy0vEqcEtkqgw2GtrFL5WuBZlCTCOQ=
google.golang.org/api v0.225.0 h1:+4/IVqBQm0MV5S+JW3kdEGC1WtOmM2mXN1LKH1LdNlw=
google.golang.org/api v0.225.0/go.mod h1:WP/0Xm4LVvMOCldfvOISnWquSRWbG2kArDZcg+W2DbY=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24=
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250227231956-55c901821b1e h1:YA5lmSs3zc/5w+xsRcHqpETkaYyK63ivEPzNTcUUlSA=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250227231956-55c901821b1e/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250303144028-a0af3efb3deb h1:TLPQVbx1GJ8VKZxz52VAxl1EBgKXXbTiU9Fc5fZeLn4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250303144028-a0af3efb3deb/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
Expand Down
57 changes: 29 additions & 28 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,43 +242,44 @@ func (c ConnectionInfo) TLSConfig() *tls.Config {
pool.AddCert(caCert)
}

// For CAS instances, we can rely on the DNS name to verify the server identity.
if c.ServerCAMode != "" && c.ServerCAMode != "GOOGLE_MANAGED_INTERNAL_CA" {
// By default, use Standard TLS hostname verification name to
// verify the server identity.

// If the connector was configured with a domain name, use that domain name
// to validate the certificate. Otherwise, use the DNS name from the
// instance ConnectionInfo API response.
serverName := c.ConnectionName.DomainName()
if serverName == "" {
serverName = c.DNSName
}

// If the instance metadata does not contain a domain name, use the legacy
// validation checking the CN field for the instance connection name.
if c.DNSName == "" {
return &tls.Config{
ServerName: serverName,
ServerName: c.ConnectionName.String(),
Certificates: []tls.Certificate{c.ClientCertificate},
RootCAs: pool,
MinVersion: tls.VersionTLS13,
// We need to set InsecureSkipVerify to true due to
// https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194
// https://tip.golang.org/doc/go1.11#crypto/x509
//
// Since we have a secure channel to the Cloud SQL API which we use to
// retrieve the certificates, we instead need to implement our own
// VerifyPeerCertificate function that will verify that the certificate
// is OK.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool),
MinVersion: tls.VersionTLS13,
}
}
// For legacy instances use the custom TLS validation

// If the connector was configured with a domain name, use that domain name
// to validate the certificate. Otherwise, use the DNS name from the
// instance metadata retrieved from the ConnectSettings API endpoint.
serverName := c.ConnectionName.DomainName()
if serverName == "" {
serverName = c.DNSName
}

// By default, use Standard TLS hostname verification name to
// verify the server identity.
return &tls.Config{
ServerName: c.ConnectionName.String(),
ServerName: serverName,
Certificates: []tls.Certificate{c.ClientCertificate},
RootCAs: pool,
// We need to set InsecureSkipVerify to true due to
// https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194
// https://tip.golang.org/doc/go1.11#crypto/x509
//
// Since we have a secure channel to the Cloud SQL API which we use to
// retrieve the certificates, we instead need to implement our own
// VerifyPeerCertificate function that will verify that the certificate
// is OK.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool),
MinVersion: tls.VersionTLS13,
MinVersion: tls.VersionTLS13,
}

}

// verifyPeerCertificateFunc creates a VerifyPeerCertificate func that
Expand Down
35 changes: 32 additions & 3 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,26 @@ func fetchMetadata(

// resolve DnsName into IP address for PSC
// Note that we have to check for PSC enablement first because CAS instances also set the DnsName.
if db.PscEnabled && db.DnsName != "" {
ipAddrs[PSC] = db.DnsName
if db.PscEnabled {
// Search the dns_names field for the PSC DNS Name.
pscDNSName := ""
for _, dnm := range db.DnsNames {
if dnm.Name != "" &&
dnm.ConnectionType == "PRIVATE_SERVICE_CONNECT" && dnm.DnsScope == "INSTANCE" {
Comment on lines +107 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a comment here to explain this logic, especially since DnsScope may not be trivial

pscDNSName = dnm.Name
break
}
}

// If the psc dns name was not found, use the legacy dns_name field
if pscDNSName == "" && db.DnsName != "" {
pscDNSName = db.DnsName
}

// If the psc dns name was found, add it to the ipaddrs map.
if pscDNSName != "" {
ipAddrs[PSC] = pscDNSName
}
}

if len(ipAddrs) == 0 {
Expand All @@ -128,11 +146,22 @@ func fetchMetadata(
caCerts = append(caCerts, caCert)
}

// Find a DNS name to use to validate the certificate from the dns_names field. Any
// name in the list may be used to validate the server TLS certificate.
// Fall back to legacy dns_name field if necessary.
var serverName string
if len(db.DnsNames) > 0 {
serverName = db.DnsNames[0].Name
}
if serverName == "" {
serverName = db.DnsName
}

m = metadata{
ipAddrs: ipAddrs,
serverCACert: caCerts,
version: db.DatabaseVersion,
dnsName: db.DnsName,
dnsName: serverName,
serverCAMode: db.ServerCaMode,
}

Expand Down
4 changes: 2 additions & 2 deletions internal/mock/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ func (ct *TLSCertificates) generateServerCertWithCn(cn string) *x509.Certificate
// serverChain creates a []tls.Certificate for use with a TLS server socket.
// serverCAMode controls whether this returns a legacy or CAS server
// certificate.
func (ct *TLSCertificates) serverChain(serverCAMode string) []tls.Certificate {
func (ct *TLSCertificates) serverChain(useStandardTLSValidation bool) []tls.Certificate {
// if this server is running in legacy mode
if serverCAMode == "" || serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
if !useStandardTLSValidation {
return []tls.Certificate{{
Certificate: [][]byte{ct.serverCert.Raw, ct.serverCaCert.Raw},
PrivateKey: ct.serverKey,
Expand Down
Loading
Loading