Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions credentials/xds/xds.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ package xds
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"sync/atomic"
"time"
Expand Down Expand Up @@ -138,40 +136,6 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
if err != nil {
return nil, nil, err
}
cfg.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// Parse all raw certificates presented by the peer.
var certs []*x509.Certificate
for _, rc := range rawCerts {
cert, err := x509.ParseCertificate(rc)
if err != nil {
return err
}
certs = append(certs, cert)
}

// Build the intermediates list and verify that the leaf certificate
// is signed by one of the root certificates.
intermediates := x509.NewCertPool()
for _, cert := range certs[1:] {
intermediates.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: cfg.RootCAs,
Intermediates: intermediates,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
if _, err := certs[0].Verify(opts); err != nil {
return err
}
// The SANs sent by the MeshCA are encoded as SPIFFE IDs. We need to
// only look at the SANs on the leaf cert.
if cert := certs[0]; !hi.MatchingSANExists(cert) {
// TODO: Print the complete certificate once the x509 package
// supports a String() method on the Certificate type.
return fmt.Errorf("xds: received SANs {DNSNames: %v, EmailAddresses: %v, IPAddresses: %v, URIs: %v} do not match any of the accepted SANs", cert.DNSNames, cert.EmailAddresses, cert.IPAddresses, cert.URIs)
}
return nil
}

// Perform the TLS handshake with the tls.Config that we have. We run the
// actual Handshake() function in a goroutine because we need to respect the
Expand Down
186 changes: 165 additions & 21 deletions credentials/xds/xds_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,26 @@ import (
"time"
"unsafe"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/credentials/spiffe"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/xds/matcher"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)

const (
defaultTestTimeout = 1 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
defaultTestCertSAN = "abc.test.example.com"
authority = "authority"
authority = "authority"
defaultTestCertSAN = "abc.test.example.com"
defaultTestCertSANSPIFFE = "*.test.google.fr"
defaultTestShortTimeout = 10 * time.Millisecond
defaultTestTimeout = 1 * time.Second
)

type s struct {
Expand All @@ -60,8 +64,54 @@ func Test(t *testing.T) {

// Helper function to create a real TLS client credentials which is used as
// fallback credentials from multiple tests.
func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
func makeFallbackClientCreds(t *testing.T, useSPIFFECreds bool, isMTLS bool) credentials.TransportCredentials {
var creds credentials.TransportCredentials
var err error
if useSPIFFECreds {
if isMTLS {
b, err := os.ReadFile(testdata.Path("spiffe_end2end/ca.pem"))
if err != nil {
t.Fatal(err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
t.Fatalf("failed to append certificates")
}
cert, err := tls.LoadX509KeyPair(testdata.Path("spiffe_end2end/client_spiffe.pem"), testdata.Path("spiffe_end2end/client.key"))
if err != nil {
t.Fatal(err)
}
creds = credentials.NewTLS(&tls.Config{
ServerName: "x.test.example.com",
RootCAs: cp,
Certificates: []tls.Certificate{cert},
})
} else {
creds, err = credentials.NewClientTLSFromFile(testdata.Path("spiffe_end2end/ca.pem"), "x.test.example.com")
}
} else {
if isMTLS {
b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
if err != nil {
t.Fatal(err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
t.Fatalf("failed to append certificates")
}
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem"))
if err != nil {
t.Fatal(err)
}
creds = credentials.NewTLS(&tls.Config{
ServerName: "x.test.example.com",
RootCAs: cp,
Certificates: []tls.Certificate{cert},
})
} else {
creds, err = credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
}
}
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -157,6 +207,25 @@ func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
return handshakeResult{connState: conn.ConnectionState()}
}

// A handshake function which simulates a successful handshake without client
// authentication (server does not request for client certificate during the
// handshake here).
func testServerTLSHandshakeSPIFFE(rawConn net.Conn) handshakeResult {
cert, err := tls.LoadX509KeyPair(testdata.Path("spiffe_end2end/server_spiffe.pem"), testdata.Path("spiffe_end2end/server.key"))
if err != nil {
return handshakeResult{err: err}
}
cfg := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{"h2"},
}
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
return handshakeResult{err: err}
}
return handshakeResult{connState: conn.ConnectionState()}
}

// A handshake function which simulates a successful handshake with mutual
// authentication.
func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
Expand All @@ -173,6 +242,32 @@ func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
cfg := &tls.Config{
Certificates: []tls.Certificate{cert},
ClientCAs: roots,
ClientAuth: tls.RequireAnyClientCert,
}
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
return handshakeResult{err: err}
}
return handshakeResult{connState: conn.ConnectionState()}
}

// A handshake function which simulates a successful handshake with mutual
// authentication.
func testServerMutualTLSHandshakeSPIFFE(rawConn net.Conn) handshakeResult {
cert, err := tls.LoadX509KeyPair(testdata.Path("spiffe_end2end/server_spiffe.pem"), testdata.Path("spiffe_end2end/server.key"))
if err != nil {
return handshakeResult{err: err}
}
pemData, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
if err != nil {
return handshakeResult{err: err}
}
roots := x509.NewCertPool()
roots.AppendCertsFromPEM(pemData)
cfg := &tls.Config{
Certificates: []tls.Certificate{cert},
ClientCAs: roots,
ClientAuth: tls.RequireAndVerifyClientCert,
}
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
Expand Down Expand Up @@ -218,6 +313,18 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
}

func makeSPIFFEBundleProvider(t *testing.T, spiffeBundlePath string) *fakeProvider {
bytes, err := os.ReadFile(testdata.Path(spiffeBundlePath))
if err != nil {
t.Fatal(err)
}
spiffeBundle, err := spiffe.BundleMapFromBytes(bytes)
if err != nil {
t.Fatal(err)
}
return &fakeProvider{km: &certprovider.KeyMaterial{SPIFFEBundleMap: spiffeBundle}}
}

// newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
// context value added to it.
func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
Expand Down Expand Up @@ -294,7 +401,7 @@ func (s) TestClientCredsWithoutFallback(t *testing.T) {
// HandshakeInfo is invalid because it does not contain the expected certificate
// providers.
func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, false, false)}
creds, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand All @@ -311,7 +418,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
// TestClientCredsProviderFailure verifies the cases where an expected
// certificate provider is missing in the HandshakeInfo value in the context.
func (s) TestClientCredsProviderFailure(t *testing.T) {
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, false, false)}
creds, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand Down Expand Up @@ -353,6 +460,8 @@ func (s) TestClientCredsSuccess(t *testing.T) {
desc string
handshakeFunc testHandshakeFunc
handshakeInfoCtx func(ctx context.Context) context.Context
useSPIFFECreds bool
isMTLS bool
}{
{
desc: "fallback",
Expand All @@ -376,22 +485,40 @@ func (s) TestClientCredsSuccess(t *testing.T) {
handshakeInfoCtx: func(ctx context.Context) context.Context {
return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
},
isMTLS: true,
},
{
desc: "mTLS with no acceptedSANs specified",
handshakeFunc: testServerMutualTLSHandshake,
handshakeInfoCtx: func(ctx context.Context) context.Context {
return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
},
isMTLS: true,
},
{
desc: "SPIFFE TLS",
handshakeFunc: testServerTLSHandshakeSPIFFE,
handshakeInfoCtx: func(ctx context.Context) context.Context {
return newTestContextWithHandshakeInfo(ctx, makeSPIFFEBundleProvider(t, "spiffe_end2end/client_spiffebundle.json"), nil, defaultTestCertSANSPIFFE)
},
useSPIFFECreds: true,
},
{
desc: "SPIFFE mTLS",
handshakeFunc: testServerMutualTLSHandshakeSPIFFE,
handshakeInfoCtx: func(ctx context.Context) context.Context {
return newTestContextWithHandshakeInfo(ctx, makeSPIFFEBundleProvider(t, "spiffe_end2end/server_spiffebundle.json"), makeIdentityProvider(t, "spiffe_end2end/server_spiffe.pem", "spiffe_end2end/server.key"), defaultTestCertSANSPIFFE)
},
useSPIFFECreds: true,
isMTLS: true,
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
defer ts.stop()

opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, test.useSPIFFECreds, test.isMTLS)}
creds, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand Down Expand Up @@ -428,7 +555,7 @@ func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
})
defer ts.stop()

opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, false, false)}
creds, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand Down Expand Up @@ -464,12 +591,14 @@ func (s) TestClientCredsHandshakeTimeout(t *testing.T) {

// TestClientCredsHandshakeFailure verifies different handshake failure cases.
func (s) TestClientCredsHandshakeFailure(t *testing.T) {
const wantErrCode = codes.Unknown
tests := []struct {
desc string
handshakeFunc testHandshakeFunc
rootProvider certprovider.Provider
san string
wantErr string
desc string
handshakeFunc testHandshakeFunc
rootProvider certprovider.Provider
san string
useSPIFFECreds bool
wantErr string
}{
{
desc: "cert validation failure",
Expand All @@ -485,14 +614,22 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
san: "bad-san",
wantErr: "do not match any of the accepted SANs",
},
{
desc: "SPIFFE Bundle validation failure",
handshakeFunc: testServerTLSHandshakeSPIFFE,
rootProvider: makeSPIFFEBundleProvider(t, "spiffe_end2end/server_spiffebundle.json"),
san: defaultTestCertSANSPIFFE,
useSPIFFECreds: true,
wantErr: "no bundle found for peer certificates",
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
defer ts.stop()

opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, false, false)}
creds, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand All @@ -507,8 +644,15 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
_, _, err = creds.ClientHandshake(ctx, authority, conn)
if err == nil {
t.Fatalf("ClientHandshake() got no error, want error to contain %v", test.wantErr)
}
if !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("ClientHandshake() got error %v, want error to contain %v", err, test.wantErr)
}
if status.Code(err) != wantErrCode {
t.Fatalf("ClientHandshake() got error code %v, want error code %v", status.Code(err), wantErrCode)
}
})
}
Expand All @@ -523,7 +667,7 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
defer ts.stop()

opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, false, false)}
creds, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand Down Expand Up @@ -581,7 +725,7 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {

// TestClientClone verifies the Clone() method on client credentials.
func (s) TestClientClone(t *testing.T) {
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t, false, false)}
orig, err := NewClientCredentials(opts)
if err != nil {
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
Expand Down
Loading
Loading