Skip to content

Commit 69b542a

Browse files
credentials/tls: verify overwritten authority against only leaf certificate (#8831)
Fixes: #8721 This PR changes the function that verifies the per RPC authority being overwritten only against the peer's leaf certificate as opposed to checking against the whole chain of peer certificates which it was doing earlier. RELEASE NOTES: * credentials/tls: Fixes a bug where per-RPC authority verification was performed against the entire peer certificate chain instead of strictly checking the leaf certificate.
1 parent e6ca417 commit 69b542a

File tree

2 files changed

+179
-10
lines changed

2 files changed

+179
-10
lines changed

credentials/credentials_ext_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@ package credentials_test
2020

2121
import (
2222
"context"
23+
"crypto/rand"
24+
"crypto/rsa"
2325
"crypto/tls"
26+
"crypto/x509"
27+
"crypto/x509/pkix"
2428
"fmt"
29+
"math/big"
2530
"net"
31+
"strings"
2632
"testing"
2733
"time"
2834

@@ -365,3 +371,168 @@ func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
365371
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK)
366372
}
367373
}
374+
375+
// TestAuthorityOverrideWithCertChain tests that the authority being used to
376+
// override per-RPC authority is validated against the leaf certificate only
377+
// and not against the intermediate certificates.
378+
func (s) TestAuthorityOverrideWithCertChain(t *testing.T) {
379+
rootCert, certChain, leafKey := generateCertChain(t, "root.example.com", "intermediate.example.com", "*.leaf.example.com")
380+
381+
// Construct server credentials from leaf and intermediate certificates.
382+
serverCert := tls.Certificate{
383+
Certificate: [][]byte{certChain[0].Raw, certChain[1].Raw},
384+
PrivateKey: leafKey,
385+
}
386+
serverCreds := credentials.NewServerTLSFromCert(&serverCert)
387+
388+
// Create client credentials trusting the Root CA.
389+
certPool := x509.NewCertPool()
390+
certPool.AddCert(rootCert)
391+
clientCreds := credentials.NewTLS(&tls.Config{
392+
RootCAs: certPool,
393+
ServerName: "test1.leaf.example.com",
394+
})
395+
396+
tests := []struct {
397+
name string
398+
authority string
399+
wantCode codes.Code
400+
wantErr string
401+
}{
402+
{
403+
name: "AuthorityMatchesIntermediate",
404+
authority: "intermediate.example.com",
405+
wantCode: codes.Unavailable,
406+
wantErr: "failed to validate authority",
407+
},
408+
{
409+
name: "AuthorityMatchesLeaf",
410+
authority: "test2.leaf.example.com",
411+
wantCode: codes.OK,
412+
},
413+
}
414+
415+
for _, tt := range tests {
416+
t.Run(tt.name, func(t *testing.T) {
417+
// Setup and start the stub server.
418+
ss := &stubserver.StubServer{
419+
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
420+
if err := authorityChecker(ctx, tt.authority); err != nil {
421+
return nil, err
422+
}
423+
return &testpb.Empty{}, nil
424+
},
425+
}
426+
if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
427+
t.Fatalf("failed to start server: %v", err)
428+
}
429+
defer ss.Stop()
430+
431+
cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
432+
if err != nil {
433+
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
434+
}
435+
defer cc.Close()
436+
437+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
438+
defer cancel()
439+
440+
_, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.authority))
441+
if got := status.Code(err); got != tt.wantCode {
442+
t.Fatalf("EmptyCall() with authority %q: got code %v, want %v", tt.authority, got, tt.wantCode)
443+
}
444+
if tt.wantErr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErr)) {
445+
t.Fatalf("EmptyCall() with authority %q: expected error to contain %q, got %v", tt.authority, tt.wantErr, err)
446+
}
447+
})
448+
}
449+
}
450+
451+
// certConfig defines the configuration for generating a certificate.
452+
type certConfig struct {
453+
commonName string
454+
dnsNames []string
455+
isCA bool
456+
serial int64
457+
parentCert *x509.Certificate
458+
parentKey *rsa.PrivateKey
459+
}
460+
461+
// createCertificate generates a certificate based on the provided certConfig.
462+
// It creates self-signed certificates if parentCert is nil otherwise it creates
463+
// certificates signed by a parent certificate.
464+
func createCertificate(t *testing.T, cfg certConfig) (*x509.Certificate, *rsa.PrivateKey) {
465+
t.Helper()
466+
467+
key, err := rsa.GenerateKey(rand.Reader, 2048)
468+
if err != nil {
469+
t.Fatal(err)
470+
}
471+
472+
now := time.Now()
473+
tmpl := &x509.Certificate{
474+
SerialNumber: big.NewInt(cfg.serial),
475+
Subject: pkix.Name{CommonName: cfg.commonName},
476+
DNSNames: cfg.dnsNames,
477+
NotBefore: now.Add(-time.Hour),
478+
NotAfter: now.Add(time.Hour),
479+
BasicConstraintsValid: true,
480+
IsCA: cfg.isCA,
481+
}
482+
483+
// If no parent is provided, the certificate is self-signed
484+
signingCert := cfg.parentCert
485+
signingKey := cfg.parentKey
486+
if signingCert == nil {
487+
signingCert = tmpl
488+
signingKey = key
489+
}
490+
491+
der, err := x509.CreateCertificate(rand.Reader, tmpl, signingCert, key.Public(), signingKey)
492+
if err != nil {
493+
t.Fatal(err)
494+
}
495+
496+
cert, err := x509.ParseCertificate(der)
497+
if err != nil {
498+
t.Fatal(err)
499+
}
500+
501+
return cert, key
502+
}
503+
504+
// generateCertChain creates a 3 certificate chain (Root -> Intermediate ->
505+
// Leaf). It returns the root certificate, a slice containing the leaf and
506+
// intermediate certificates in the order [leaf, intermediate], and the private
507+
// key for the leaf certificate.
508+
func generateCertChain(t *testing.T, rootName, interName, leafName string) (root *x509.Certificate, chain []*x509.Certificate, leafKey *rsa.PrivateKey) {
509+
t.Helper()
510+
511+
rootCfg := certConfig{
512+
commonName: rootName,
513+
isCA: true,
514+
}
515+
root, rootKey := createCertificate(t, rootCfg)
516+
517+
interCfg := certConfig{
518+
commonName: interName,
519+
dnsNames: []string{interName},
520+
isCA: true,
521+
serial: 2,
522+
parentCert: root,
523+
parentKey: rootKey,
524+
}
525+
intermediate, interKey := createCertificate(t, interCfg)
526+
527+
leafCfg := certConfig{
528+
commonName: leafName,
529+
dnsNames: []string{leafName},
530+
isCA: false,
531+
serial: 3,
532+
parentCert: intermediate,
533+
parentKey: interKey,
534+
}
535+
leaf, leafKey := createCertificate(t, leafCfg)
536+
537+
return root, []*x509.Certificate{leaf, intermediate}, leafKey
538+
}

credentials/tls.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"context"
2323
"crypto/tls"
2424
"crypto/x509"
25-
"errors"
2625
"fmt"
2726
"net"
2827
"net/url"
@@ -52,22 +51,21 @@ func (t TLSInfo) AuthType() string {
5251
}
5352

5453
// ValidateAuthority validates the provided authority being used to override the
55-
// :authority header by verifying it against the peer certificates. It returns a
54+
// :authority header by verifying it against the peer certificate. It returns a
5655
// non-nil error if the validation fails.
5756
func (t TLSInfo) ValidateAuthority(authority string) error {
58-
var errs []error
5957
host, _, err := net.SplitHostPort(authority)
6058
if err != nil {
6159
host = authority
6260
}
63-
for _, cert := range t.State.PeerCertificates {
64-
var err error
65-
if err = cert.VerifyHostname(host); err == nil {
66-
return nil
67-
}
68-
errs = append(errs, err)
61+
62+
// Verify authority against the leaf certificate.
63+
if len(t.State.PeerCertificates) == 0 {
64+
// This is not expected to happen as the TLS handshake has already
65+
// completed and should have populated PeerCertificates.
66+
return fmt.Errorf("credentials: no peer certificates found to verify authority %q", host)
6967
}
70-
return fmt.Errorf("credentials: invalid authority %q: %v", authority, errors.Join(errs...))
68+
return t.State.PeerCertificates[0].VerifyHostname(host)
7169
}
7270

7371
// cipherSuiteLookup returns the string version of a TLS cipher suite ID.

0 commit comments

Comments
 (0)