diff --git a/dialer_test.go b/dialer_test.go index d8dbd084..8f9073ad 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -1222,6 +1222,67 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) { } } +func TestDialerChecksSubjectAlternativeNameAndFallsBackToCN(t *testing.T) { + + // Create an instance with custom SAN 'db.example.com' + inst := mock.NewFakeCSQLInstance( + "myProject", "myRegion", "myInstance", + mock.WithDNS("db.example.com"), + mock.WithMissingSAN("db.example.com"), // don't put db.example.com in the server cert. + ) + + // resolve db.example.com to the same instance + wantName, _ := instance.ParseConnNameWithDomainName("myProject:myRegion:myInstance", "db.example.com") + + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + + dialerOptions: []Option{ + WithTokenSource(mock.EmptyTokenSource{}), + WithResolver(&fakeResolver{ + entries: map[string]instance.ConnName{ + "db.example.com": wantName, + "myProject:myRegion:myInstance": wantName, + }, + }), + }, + }) + + tcs := []struct { + desc string + icn string + }{ + { + desc: "Fallback from connect with Instance Connection Name", + icn: "myProject:myRegion:myInstance", + }, + { + desc: "Fallback from connect with configured domain name", + icn: "db.example.com", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + + // Dial 'db2.example.com'. This succeed overall. + // First the Hostname check will fail because the certificate does not + // contain db2.example.com + // Then the CN field check will succeed, because the instance connection + // name matches. + _, err := d.Dial( + context.Background(), tc.icn, + ) + if err != nil { + t.Fatal("Want no error. Got: ", err) + } + }) + } +} + func TestDialerRefreshesAfterRotateCACerts(t *testing.T) { tcs := []struct { desc string diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index 10e32af6..8fa62b47 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -242,87 +242,28 @@ func (c ConnectionInfo) TLSConfig() *tls.Config { pool.AddCert(caCert) } - // If the instance metadata does not contain a domain name, use the legacy - // validation checking the CN field for the instance connection name. - if c.DNSName == "" { - return &tls.Config{ - ServerName: c.ConnectionName.String(), - Certificates: []tls.Certificate{c.ClientCertificate}, - RootCAs: pool, - // We need to set InsecureSkipVerify to true due to - // https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194 - // https://tip.golang.org/doc/go1.11#crypto/x509 - // - // Since we have a secure channel to the Cloud SQL API which we use to - // retrieve the certificates, we instead need to implement our own - // VerifyPeerCertificate function that will verify that the certificate - // is OK. - InsecureSkipVerify: true, - VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool), - MinVersion: tls.VersionTLS13, - } - } - - // If the connector was configured with a domain name, use that domain name - // to validate the certificate. Otherwise, use the DNS name from the - // instance metadata retrieved from the ConnectSettings API endpoint. - serverName := c.ConnectionName.DomainName() - if serverName == "" { + var serverName string + if c.ConnectionName.HasDomainName() { + // If the connector was configured with a DNS name, use the DNS name from + // the configuration to validate the server certificate. + serverName = c.ConnectionName.DomainName() + } else { + // If the connector was configured with an Instance Connection Name, + // use the DNS name from the instance metadata. serverName = c.DNSName } - // By default, use Standard TLS hostname verification name to - // verify the server identity. return &tls.Config{ ServerName: serverName, Certificates: []tls.Certificate{c.ClientCertificate}, RootCAs: pool, MinVersion: tls.VersionTLS13, - } - -} - -// verifyPeerCertificateFunc creates a VerifyPeerCertificate func that -// verifies that the peer certificate is in the cert pool. We need to define -// our own because CloudSQL instances use the instance name (e.g., -// my-project:my-instance) instead of a valid domain name for the certificate's -// Common Name. -func verifyPeerCertificateFunc( - cn instance.ConnName, pool *x509.CertPool, -) func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - if len(rawCerts) == 0 { - return errtype.NewDialError( - "no certificate to verify", cn.String(), nil, - ) - } - - cert, err := x509.ParseCertificate(rawCerts[0]) - if err != nil { - return errtype.NewDialError( - "failed to parse X.509 certificate", cn.String(), err, - ) - } - - opts := x509.VerifyOptions{Roots: pool} - if _, err = cert.Verify(opts); err != nil { - return errtype.NewDialError( - "failed to verify certificate", cn.String(), err, - ) - } - - certInstanceName := fmt.Sprintf("%s:%s", cn.Project(), cn.Name()) - if cert.Subject.CommonName != certInstanceName { - return errtype.NewDialError( - fmt.Sprintf( - "certificate had CN %q, expected %q", - cert.Subject.CommonName, certInstanceName, - ), - cn.String(), - nil, - ) - } - return nil + // Replace entire default TLS verification with our custom TLS + // verification defined in verifyPeerCertificateFunc(). This allows the + // connector to gracefully and securely handle deviations from standard TLS + // hostname validation in some existing Cloud SQL certificates. + InsecureSkipVerify: true, + VerifyPeerCertificate: verifyPeerCertificateFunc(serverName, c.ConnectionName, pool), } } diff --git a/internal/cloudsql/instance_test.go b/internal/cloudsql/instance_test.go index 6e437768..443a7ba3 100644 --- a/internal/cloudsql/instance_test.go +++ b/internal/cloudsql/instance_test.go @@ -165,14 +165,6 @@ func TestConnectionInfoTLSConfig(t *testing.T) { } got := ci.TLSConfig() - wantServerName := cn.String() - if got.ServerName != wantServerName { - t.Fatalf( - "ConnectInfo return unexpected server name in TLS Config, "+ - "want = %v, got = %v", - wantServerName, got.ServerName, - ) - } if got.MinVersion != tls.VersionTLS13 { t.Fatalf( @@ -403,7 +395,7 @@ func TestConnectionInfoTLSConfigForCAS(t *testing.T) { wantRootCAs.AddCert(subCACert) // Assemble a connection info with the raw and parsed client cert // and the self-signed server certificate - wantServerName := "testing dns name" + wantServerName := "db.example.com" ci := ConnectionInfo{ DNSName: wantServerName, ClientCertificate: tls.Certificate{ @@ -434,8 +426,8 @@ func TestConnectionInfoTLSConfigForCAS(t *testing.T) { if got.Certificates[0].Leaf != ci.ClientCertificate.Leaf { t.Fatal("leaf certificates do not match") } - if got.InsecureSkipVerify { - t.Fatal("InsecureSkipVerify is true, expected false") + if !got.InsecureSkipVerify { + t.Fatal("InsecureSkipVerify is false, expected true") } if !got.RootCAs.Equal(wantRootCAs) { t.Fatalf("unexpected root CAs, got %v, want %v", got.RootCAs, wantRootCAs) diff --git a/internal/cloudsql/tls_verify.go b/internal/cloudsql/tls_verify.go new file mode 100644 index 00000000..3acd68d8 --- /dev/null +++ b/internal/cloudsql/tls_verify.go @@ -0,0 +1,157 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsql + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + + "cloud.google.com/go/cloudsqlconn/errtype" + "cloud.google.com/go/cloudsqlconn/instance" +) + +// verifyPeerCertificateFunc creates a VerifyPeerCertificate function with the +// custom TLS verification logic to gracefully and securely handle deviations +// from standard TLS hostname verification in existing Cloud SQL instance +// server certificates. +// +// This is the verification algorithm: +// +// 1. Verify the server cert CA, using the CA certs from the instance metadata. +// Reject the certificate if the CA is invalid. +// +// 2. Check that the server cert contains a SubjectAlternativeName matching the +// DNS name in the connector configuration OR the DNS Name from the instance +// metadata +// +// 3. If the SubjectAlternativeName does not match, and if the server cert +// Subject.CN field is not empty, check that the Subject.CN field contains +// the instance name. +// +// Reject the certificate if both the #2 SAN check and #3 CN checks fail. +// +// To summarize the deviations from standard TLS hostname verification: +// +// Historically, Cloud SQL creates server certificates with the instance name in +// the Subject.CN field in the format "my-project:my-instance". The connector is +// expected to check that the instance name that the connector was configured to +// dial matches the server certificate Subject.CN field. Thus, the Subject.CN +// field for most Cloud SQL instances does not contain a well-formed DNS Name. +// +// The default Go TLS hostname verification TLSConfig.serverName may be compared +// with the Subject.CN field if Subject.CN contains a well-formed DNS name. +// So the Cloud SQL server certs break the standard hostname verification in Go. +// See: +// - https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194 +// - https://tip.golang.org/doc/go1.11#crypto/x509 +// +// Also, there are times when the instance metadata reports that an instance has +// a DNS name, but that DNS name does not yet appear in the SAN records of the +// server certificate. The client should fall back to validating the hostname +// using the instance name in the Subject.CN field. +func verifyPeerCertificateFunc( + serverName string, cn instance.ConnName, roots *x509.CertPool, +) func(certs [][]byte, chain [][]*x509.Certificate) error { + return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return errtype.NewDialError( + "no certificate to verify", cn.String(), nil, + ) + } + // Parse the raw certificates + certs := make([]*x509.Certificate, 0, len(rawCerts)) + var err error + for _, certBytes := range rawCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return errtype.NewDialError( + "failed to parse X.509 certificate", cn.String(), err, + ) + } + certs = append(certs, cert) + } + serverCert := certs[0] + + // Verify the validity of the certificate chain + _, err = serverCert.Verify(x509.VerifyOptions{ + Roots: roots, + }) + if err != nil { + err = &tls.CertificateVerificationError{ + UnverifiedCertificates: certs, + Err: err, + } + return errtype.NewDialError( + "failed to verify certificate", cn.String(), err, + ) + } + + var serverNameErr error + + if serverName == "" { + // The instance has no DNS name. + // Verify only the CN + return verifyCn(cn, serverCert) + } + + // The instance has a DNS name. + // First, verify the server hostname + serverNameErr = serverCert.VerifyHostname(serverName) + if serverNameErr != nil { + // If that failed, verify the CN field. + cnErr := verifyCn(cn, serverCert) + if cnErr != nil { + // If both failed, return the server hostname error. + serverNameErr = &tls.CertificateVerificationError{ + UnverifiedCertificates: certs, + Err: serverNameErr, + } + return serverNameErr + } + } + + // All checks passed + return nil + } +} + +func verifyCn(cn instance.ConnName, cert *x509.Certificate) error { + // Reject CN check if the certificate CN field is empty + if cert.Subject.CommonName == "" { + return errtype.NewDialError( + fmt.Sprintf( + "certificate CN was empty, expected %q", + cert.Subject.CommonName, + ), + cn.String(), + nil, + ) + } + + // Verify the CN field matches the instance name + certInstanceName := fmt.Sprintf("%s:%s", cn.Project(), cn.Name()) + if cert.Subject.CommonName != certInstanceName { + return errtype.NewDialError( + fmt.Sprintf( + "certificate had CN %q, expected %q", + cert.Subject.CommonName, certInstanceName, + ), + cn.String(), + nil, + ) + } + return nil +} diff --git a/internal/cloudsql/tls_verify_test.go b/internal/cloudsql/tls_verify_test.go new file mode 100644 index 00000000..873f3303 --- /dev/null +++ b/internal/cloudsql/tls_verify_test.go @@ -0,0 +1,151 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsql + +import ( + "crypto/x509" + "fmt" + "testing" + "time" + + "cloud.google.com/go/cloudsqlconn/instance" + "cloud.google.com/go/cloudsqlconn/internal/mock" +) + +func TestVerifyCertificate(t *testing.T) { + tcs := []struct { + desc string + serverName string // verify input server dns name + icn string // verify input instance connection name + cn string // cert CN + san string // cert SAN + valid bool // wants validation to succeed + }{ + { + desc: "cn match", + icn: "myProject:myRegion:myInstance", + cn: "myProject:myInstance", + valid: true, + }, + { + desc: "cn no match", + icn: "myProject:myRegion:badInstance", + cn: "myProject:myInstance", + valid: false, + }, + { + desc: "cn empty", + icn: "myProject:myRegion:myInstance", + san: "db.example.com", + valid: false, + }, + { + desc: "san match", + serverName: "db.example.com", + icn: "myProject:myRegion:myInstance", + san: "db.example.com", + valid: true, + }, + { + desc: "san no match", + serverName: "bad.example.com", + icn: "myProject:myRegion:myInstance", + san: "db.example.com", + valid: false, + }, + { + desc: "san empty match", + serverName: "empty.example.com", + icn: "myProject:myRegion:myInstance", + cn: "", + valid: false, + }, + { + desc: "san match with cn present", + serverName: "db.example.com", + icn: "myProject:myRegion:myInstance", + san: "db.example.com", + cn: "myProject:myInstance", + valid: true, + }, + { + desc: "san no match fallback to cn", + serverName: "db.example.com", + icn: "myProject:myRegion:myInstance", + san: "other.example.com", + cn: "myProject:myInstance", + valid: true, + }, + { + desc: "san empty match fallback to cn", + serverName: "db.example.com", + icn: "myProject:myRegion:myInstance", + cn: "myProject:myInstance", + valid: true, + }, + { + desc: "san no match fallback to cn and fail", + serverName: "db.example.com", + icn: "myProject:myRegion:badInstance", + san: "other.example.com", + cn: "myProject:myInstance", + valid: false, + }, + } + + tlsCerts := mock.NewTLSCertificates("myProject", "myInstance", nil, time.Now().Add(time.Hour)) + + for _, tc := range tcs { + for _, useCAS := range []string{"legacy", "cas"} { + t.Run(fmt.Sprintf( + + "%s %s", tc.desc, useCAS), func(t *testing.T) { + var sans []string + if tc.san != "" { + sans = []string{tc.san} + } + var serverChain []*x509.Certificate + if useCAS == "cas" { + serverChain = tlsCerts.CreateCASServerChain(tc.cn, sans) + } else { + serverChain = tlsCerts.CreateServerChain(tc.cn, sans) + } + + icn, _ := instance.ParseConnName(tc.icn) + + serverChainRaw := make([][]byte, len(serverChain)) + for i, cert := range serverChain { + serverChainRaw[i] = cert.Raw + } + + roots := x509.NewCertPool() + for i := 1; i < len(serverChain); i++ { + roots.AddCert(serverChain[i]) + } + + verifyFunc := verifyPeerCertificateFunc(tc.serverName, icn, roots) + err := verifyFunc(serverChainRaw, nil) + + if err != nil && tc.valid { + t.Fatalf("want no error, got %v", err) + } + if err == nil && !tc.valid { + t.Fatal("want error, got no error") + } + + }) + } + } +} diff --git a/internal/mock/certs.go b/internal/mock/certs.go index 00d9c9cc..682b322f 100644 --- a/internal/mock/certs.go +++ b/internal/mock/certs.go @@ -74,8 +74,8 @@ func mustGenerateKey() *rsa.PrivateKey { return key } -// newTLSCertificates creates a new instance of the TLSCertificates. -func newTLSCertificates(projectName, instanceName string, sans []string, clientCertExpires time.Time) *TLSCertificates { +// NewTLSCertificates creates a new instance of the TLSCertificates. +func NewTLSCertificates(projectName, instanceName string, sans []string, clientCertExpires time.Time) *TLSCertificates { c := &TLSCertificates{ clientCertExpires: clientCertExpires, projectName: projectName, @@ -140,7 +140,7 @@ func mustBuildSignedCertificate( isCa bool, subject pkix.Name, subjectPublicKey *rsa.PrivateKey, - certificateIssuer pkix.Name, + issuerCert *x509.Certificate, issuerPrivateKey *rsa.PrivateKey, notAfter time.Time, subjectAlternativeNames []string) *x509.Certificate { @@ -155,7 +155,6 @@ func mustBuildSignedCertificate( Subject: subject, SubjectKeyId: generateSKI(&subjectPublicKey.PublicKey), AuthorityKeyId: generateSKI(&issuerPrivateKey.PublicKey), - Issuer: certificateIssuer, NotBefore: time.Now(), NotAfter: notAfter, IsCA: isCa, @@ -165,7 +164,7 @@ func mustBuildSignedCertificate( DNSNames: subjectAlternativeNames, } - certDerBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, &subjectPublicKey.PublicKey, issuerPrivateKey) + certDerBytes, err := x509.CreateCertificate(rand.Reader, cert, issuerCert, &subjectPublicKey.PublicKey, issuerPrivateKey) if err != nil { panic(err) } @@ -238,7 +237,7 @@ func (ct *TLSCertificates) generateServerCertWithCn(cn string) *x509.Certificate false, name(cn), ct.serverKey, - serverCaSubject, + ct.serverCaCert, ct.serverCaKey, time.Now().Add(1*time.Hour), nil) } @@ -261,7 +260,32 @@ func (ct *TLSCertificates) serverChain(useStandardTLSValidation bool) []tls.Cert PrivateKey: ct.serverKey, Leaf: ct.casServerCertificate, }} +} +// CreateServerChain creates a legacy server certificate chain containing the +// CN and SAN fields. +func (ct *TLSCertificates) CreateServerChain(cn string, sans []string) []*x509.Certificate { + cert := mustBuildSignedCertificate( + false, + name(cn), + ct.serverKey, + ct.serverCaCert, + ct.serverCaKey, + time.Now().Add(1*time.Hour), sans) + return []*x509.Certificate{cert, ct.serverCaCert} +} + +// CreateCASServerChain creates a certificate chain containing the +// CN and SAN fields. +func (ct *TLSCertificates) CreateCASServerChain(cn string, sans []string) []*x509.Certificate { + cert := mustBuildSignedCertificate( + false, + name(cn), + ct.serverKey, + ct.serverIntermediateCaCert, + ct.serverIntermediateCaKey, + time.Now().Add(1*time.Hour), sans) + return []*x509.Certificate{cert, ct.serverIntermediateCaCert, ct.serverCaCert} } func (ct *TLSCertificates) clientCAPool() *x509.CertPool { clientCa := x509.NewCertPool() @@ -288,7 +312,7 @@ func (ct *TLSCertificates) rotateCA() { true, intermediateCaSubject, ct.serverIntermediateCaKey, - serverCaSubject, + ct.serverCaCert, ct.serverCaKey, oneYear, nil) @@ -298,7 +322,7 @@ func (ct *TLSCertificates) rotateCA() { false, name(""), ct.serverKey, - intermediateCaSubject, + ct.serverIntermediateCaCert, ct.serverIntermediateCaKey, oneYear, ct.sans) @@ -307,10 +331,10 @@ func (ct *TLSCertificates) rotateCA() { false, name(ct.projectName+":"+ct.instanceName), ct.serverKey, - serverCaSubject, + ct.serverCaCert, ct.serverCaKey, oneYear, - nil) + ct.sans) ct.rotateClientCA() } diff --git a/internal/mock/cloudsql.go b/internal/mock/cloudsql.go index 679ab1d7..144d02bc 100644 --- a/internal/mock/cloudsql.go +++ b/internal/mock/cloudsql.go @@ -54,8 +54,9 @@ type FakeCSQLInstance struct { // DNSName is the legacy field // DNSNames supersedes DNSName. - DNSName string - DNSNames []*sqladmin.DnsNameMapping + DNSName string + MissingSAN string + DNSNames []*sqladmin.DnsNameMapping useStandardTLSValidation bool serverCAMode string @@ -133,6 +134,15 @@ func WithDNS(dns string) FakeCSQLInstanceOption { } } +// WithMissingSAN will cause the omit this dns name +// from the server cert, even though it is in the metadata. +func WithMissingSAN(dns string) FakeCSQLInstanceOption { + return func(f *FakeCSQLInstance) { + f.MissingSAN = dns + } + +} + // WithDNSMapping adds the DnsNames records func WithDNSMapping(name, dnsScope, connectionType string) FakeCSQLInstanceOption { return func(f *FakeCSQLInstance) { @@ -228,17 +238,19 @@ func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceO o(&f) } sanNames := make([]string, 0, 5) - if f.DNSName != "" { + if f.DNSName != "" && f.DNSName != f.MissingSAN { sanNames = append(sanNames, f.DNSName) } for _, dnm := range f.DNSNames { - sanNames = append(sanNames, dnm.Name) + if dnm.Name != f.MissingSAN { + sanNames = append(sanNames, dnm.Name) + } } if len(sanNames) > 0 { f.useStandardTLSValidation = true } - certs := newTLSCertificates(project, name, sanNames, f.certExpiry) + certs := NewTLSCertificates(project, name, sanNames, f.certExpiry) f.Key = certs.serverKey f.Cert = certs.serverCert