Skip to content

Commit 5ffe0ef

Browse files
authored
advancedtls: populate verified chains when using custom buildVerifyFunc (#7181)
* populate verified chains when using custom buildVerifyFunc
1 parent 1db6590 commit 5ffe0ef

File tree

2 files changed

+92
-6
lines changed

2 files changed

+92
-6
lines changed

security/advancedtls/advancedtls.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ import (
4141
credinternal "google.golang.org/grpc/internal/credentials"
4242
)
4343

44+
type CertificateChains [][]*x509.Certificate
45+
4446
// HandshakeVerificationInfo contains information about a handshake needed for
4547
// verification for use when implementing the `PostHandshakeVerificationFunc`
4648
// The fields in this struct are read-only.
@@ -53,7 +55,7 @@ type HandshakeVerificationInfo struct {
5355
RawCerts [][]byte
5456
// The verification chain obtained by checking peer RawCerts against the
5557
// trust certificate bundle(s), if applicable.
56-
VerifiedChains [][]*x509.Certificate
58+
VerifiedChains CertificateChains
5759
// The leaf certificate sent from peer, if choosing to verify the peer
5860
// certificate(s) and that verification passed. This field would be nil if
5961
// either user chose not to verify or the verification failed.
@@ -552,7 +554,8 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
552554
if cfg.ServerName == "" {
553555
cfg.ServerName = authority
554556
}
555-
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn)
557+
peerVerifiedChains := CertificateChains{}
558+
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn, &peerVerifiedChains)
556559
conn := tls.Client(rawConn, cfg)
557560
errChannel := make(chan error, 1)
558561
go func() {
@@ -576,12 +579,14 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
576579
},
577580
}
578581
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
582+
info.State.VerifiedChains = peerVerifiedChains
579583
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
580584
}
581585

582586
func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
583587
cfg := credinternal.CloneTLSConfig(c.config)
584-
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
588+
peerVerifiedChains := CertificateChains{}
589+
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn, &peerVerifiedChains)
585590
conn := tls.Server(rawConn, cfg)
586591
if err := conn.Handshake(); err != nil {
587592
conn.Close()
@@ -594,6 +599,7 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti
594599
},
595600
}
596601
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
602+
info.State.VerifiedChains = peerVerifiedChains
597603
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
598604
}
599605

@@ -618,9 +624,15 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
618624
// 1. does not have a good support on root cert reloading.
619625
// 2. will ignore basic certificate check when setting InsecureSkipVerify
620626
// to true.
627+
//
628+
// peerVerifiedChains(output param): verified chain of certs from leaf to the
629+
// trust cert that the peer trusts.
630+
// 1. For server it is, client certs + Root ca that the server trusts
631+
// 2. For client it is, server certs + Root ca that the client trusts
621632
func buildVerifyFunc(c *advancedTLSCreds,
622633
serverName string,
623-
rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
634+
rawConn net.Conn,
635+
peerVerifiedChains *CertificateChains) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
624636
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
625637
chains := verifiedChains
626638
var leafCert *x509.Certificate
@@ -684,7 +696,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
684696
if c.revocationOptions != nil {
685697
verifiedChains := chains
686698
if verifiedChains == nil {
687-
verifiedChains = [][]*x509.Certificate{rawCertList}
699+
verifiedChains = CertificateChains{rawCertList}
688700
}
689701
if err := checkChainRevocation(verifiedChains, *c.revocationOptions); err != nil {
690702
return err
@@ -698,8 +710,11 @@ func buildVerifyFunc(c *advancedTLSCreds,
698710
VerifiedChains: chains,
699711
Leaf: leafCert,
700712
})
701-
return err
713+
if err != nil {
714+
return err
715+
}
702716
}
717+
*peerVerifiedChains = chains
703718
return nil
704719
}
705720
}

security/advancedtls/advancedtls_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package advancedtls
2020

2121
import (
22+
"bytes"
2223
"context"
2324
"crypto/tls"
2425
"crypto/x509"
@@ -949,6 +950,76 @@ func (s) TestClientServerHandshake(t *testing.T) {
949950
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
950951
clientAuthInfo, serverAuthInfo)
951952
}
953+
serverVerifiedChains := serverAuthInfo.(credentials.TLSInfo).State.VerifiedChains
954+
if test.serverMutualTLS && !test.serverExpectError {
955+
if len(serverVerifiedChains) == 0 {
956+
t.Fatalf("server verified chains is empty")
957+
}
958+
var clientCert *tls.Certificate
959+
if len(test.clientCert) > 0 {
960+
clientCert = &test.clientCert[0]
961+
} else if test.clientGetCert != nil {
962+
cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{})
963+
clientCert = cert
964+
} else if test.clientIdentityProvider != nil {
965+
km, _ := test.clientIdentityProvider.KeyMaterial(context.TODO())
966+
clientCert = &km.Certs[0]
967+
}
968+
if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) {
969+
t.Fatal("server verifiedChains leaf cert doesn't match client cert")
970+
}
971+
972+
var serverRoot *x509.CertPool
973+
if test.serverRoot != nil {
974+
serverRoot = test.serverRoot
975+
} else if test.serverGetRoot != nil {
976+
result, _ := test.serverGetRoot(&GetRootCAsParams{})
977+
serverRoot = result.TrustCerts
978+
} else if test.serverRootProvider != nil {
979+
km, _ := test.serverRootProvider.KeyMaterial(context.TODO())
980+
serverRoot = km.Roots
981+
}
982+
serverVerifiedChainsCp := x509.NewCertPool()
983+
serverVerifiedChainsCp.AddCert(serverVerifiedChains[0][len(serverVerifiedChains[0])-1])
984+
if !serverVerifiedChainsCp.Equal(serverRoot) {
985+
t.Fatalf("server verified chain hierarchy doesn't match")
986+
}
987+
}
988+
clientVerifiedChains := clientAuthInfo.(credentials.TLSInfo).State.VerifiedChains
989+
if test.serverMutualTLS && !test.clientExpectHandshakeError {
990+
if len(clientVerifiedChains) == 0 {
991+
t.Fatalf("client verified chains is empty")
992+
}
993+
var serverCert *tls.Certificate
994+
if len(test.serverCert) > 0 {
995+
serverCert = &test.serverCert[0]
996+
} else if test.serverGetCert != nil {
997+
cert, _ := test.serverGetCert(&tls.ClientHelloInfo{})
998+
serverCert = cert[0]
999+
} else if test.serverIdentityProvider != nil {
1000+
km, _ := test.serverIdentityProvider.KeyMaterial(context.TODO())
1001+
serverCert = &km.Certs[0]
1002+
}
1003+
if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) {
1004+
t.Fatal("client verifiedChains leaf cert doesn't match server cert")
1005+
}
1006+
1007+
var clientRoot *x509.CertPool
1008+
if test.clientRoot != nil {
1009+
clientRoot = test.clientRoot
1010+
} else if test.clientGetRoot != nil {
1011+
result, _ := test.clientGetRoot(&GetRootCAsParams{})
1012+
clientRoot = result.TrustCerts
1013+
} else if test.clientRootProvider != nil {
1014+
km, _ := test.clientRootProvider.KeyMaterial(context.TODO())
1015+
clientRoot = km.Roots
1016+
}
1017+
clientVerifiedChainsCp := x509.NewCertPool()
1018+
clientVerifiedChainsCp.AddCert(clientVerifiedChains[0][len(clientVerifiedChains[0])-1])
1019+
if !clientVerifiedChainsCp.Equal(clientRoot) {
1020+
t.Fatalf("client verified chain hierarchy doesn't match")
1021+
}
1022+
}
9521023
})
9531024
}
9541025
}

0 commit comments

Comments
 (0)