@@ -20,9 +20,15 @@ package credentials_test
2020
2121import (
2222 "context"
23+ "crypto/rand"
24+ "crypto/rsa"
2325 "crypto/tls"
26+ "crypto/x509"
27+ "crypto/x509/pkix"
2428 "fmt"
29+ "math/big"
2530 "net"
31+ "strings"
2632 "testing"
2733 "time"
2834
@@ -365,3 +371,168 @@ func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
365371 t .Fatalf ("EmptyCall() returned status %v, want %v" , status .Code (err ), codes .OK )
366372 }
367373}
374+
375+ // TestAuthorityOverrideWithCertChain tests that the authority being used to
376+ // override per-RPC authority is validated against the leaf certificate only
377+ // and not against the intermediate certificates.
378+ func (s ) TestAuthorityOverrideWithCertChain (t * testing.T ) {
379+ rootCert , certChain , leafKey := generateCertChain (t , "root.example.com" , "intermediate.example.com" , "*.leaf.example.com" )
380+
381+ // Construct server credentials from leaf and intermediate certificates.
382+ serverCert := tls.Certificate {
383+ Certificate : [][]byte {certChain [0 ].Raw , certChain [1 ].Raw },
384+ PrivateKey : leafKey ,
385+ }
386+ serverCreds := credentials .NewServerTLSFromCert (& serverCert )
387+
388+ // Create client credentials trusting the Root CA.
389+ certPool := x509 .NewCertPool ()
390+ certPool .AddCert (rootCert )
391+ clientCreds := credentials .NewTLS (& tls.Config {
392+ RootCAs : certPool ,
393+ ServerName : "test1.leaf.example.com" ,
394+ })
395+
396+ tests := []struct {
397+ name string
398+ authority string
399+ wantCode codes.Code
400+ wantErr string
401+ }{
402+ {
403+ name : "AuthorityMatchesIntermediate" ,
404+ authority : "intermediate.example.com" ,
405+ wantCode : codes .Unavailable ,
406+ wantErr : "failed to validate authority" ,
407+ },
408+ {
409+ name : "AuthorityMatchesLeaf" ,
410+ authority : "test2.leaf.example.com" ,
411+ wantCode : codes .OK ,
412+ },
413+ }
414+
415+ for _ , tt := range tests {
416+ t .Run (tt .name , func (t * testing.T ) {
417+ // Setup and start the stub server.
418+ ss := & stubserver.StubServer {
419+ EmptyCallF : func (ctx context.Context , _ * testpb.Empty ) (* testpb.Empty , error ) {
420+ if err := authorityChecker (ctx , tt .authority ); err != nil {
421+ return nil , err
422+ }
423+ return & testpb.Empty {}, nil
424+ },
425+ }
426+ if err := ss .StartServer (grpc .Creds (serverCreds )); err != nil {
427+ t .Fatalf ("failed to start server: %v" , err )
428+ }
429+ defer ss .Stop ()
430+
431+ cc , err := grpc .NewClient (ss .Address , grpc .WithTransportCredentials (clientCreds ))
432+ if err != nil {
433+ t .Fatalf ("grpc.NewClient(%q) = %v" , ss .Address , err )
434+ }
435+ defer cc .Close ()
436+
437+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
438+ defer cancel ()
439+
440+ _ , err = testgrpc .NewTestServiceClient (cc ).EmptyCall (ctx , & testpb.Empty {}, grpc .CallAuthority (tt .authority ))
441+ if got := status .Code (err ); got != tt .wantCode {
442+ t .Fatalf ("EmptyCall() with authority %q: got code %v, want %v" , tt .authority , got , tt .wantCode )
443+ }
444+ if tt .wantErr != "" && (err == nil || ! strings .Contains (err .Error (), tt .wantErr )) {
445+ t .Fatalf ("EmptyCall() with authority %q: expected error to contain %q, got %v" , tt .authority , tt .wantErr , err )
446+ }
447+ })
448+ }
449+ }
450+
451+ // certConfig defines the configuration for generating a certificate.
452+ type certConfig struct {
453+ commonName string
454+ dnsNames []string
455+ isCA bool
456+ serial int64
457+ parentCert * x509.Certificate
458+ parentKey * rsa.PrivateKey
459+ }
460+
461+ // createCertificate generates a certificate based on the provided certConfig.
462+ // It creates self-signed certificates if parentCert is nil otherwise it creates
463+ // certificates signed by a parent certificate.
464+ func createCertificate (t * testing.T , cfg certConfig ) (* x509.Certificate , * rsa.PrivateKey ) {
465+ t .Helper ()
466+
467+ key , err := rsa .GenerateKey (rand .Reader , 2048 )
468+ if err != nil {
469+ t .Fatal (err )
470+ }
471+
472+ now := time .Now ()
473+ tmpl := & x509.Certificate {
474+ SerialNumber : big .NewInt (cfg .serial ),
475+ Subject : pkix.Name {CommonName : cfg .commonName },
476+ DNSNames : cfg .dnsNames ,
477+ NotBefore : now .Add (- time .Hour ),
478+ NotAfter : now .Add (time .Hour ),
479+ BasicConstraintsValid : true ,
480+ IsCA : cfg .isCA ,
481+ }
482+
483+ // If no parent is provided, the certificate is self-signed
484+ signingCert := cfg .parentCert
485+ signingKey := cfg .parentKey
486+ if signingCert == nil {
487+ signingCert = tmpl
488+ signingKey = key
489+ }
490+
491+ der , err := x509 .CreateCertificate (rand .Reader , tmpl , signingCert , key .Public (), signingKey )
492+ if err != nil {
493+ t .Fatal (err )
494+ }
495+
496+ cert , err := x509 .ParseCertificate (der )
497+ if err != nil {
498+ t .Fatal (err )
499+ }
500+
501+ return cert , key
502+ }
503+
504+ // generateCertChain creates a 3 certificate chain (Root -> Intermediate ->
505+ // Leaf). It returns the root certificate, a slice containing the leaf and
506+ // intermediate certificates in the order [leaf, intermediate], and the private
507+ // key for the leaf certificate.
508+ func generateCertChain (t * testing.T , rootName , interName , leafName string ) (root * x509.Certificate , chain []* x509.Certificate , leafKey * rsa.PrivateKey ) {
509+ t .Helper ()
510+
511+ rootCfg := certConfig {
512+ commonName : rootName ,
513+ isCA : true ,
514+ }
515+ root , rootKey := createCertificate (t , rootCfg )
516+
517+ interCfg := certConfig {
518+ commonName : interName ,
519+ dnsNames : []string {interName },
520+ isCA : true ,
521+ serial : 2 ,
522+ parentCert : root ,
523+ parentKey : rootKey ,
524+ }
525+ intermediate , interKey := createCertificate (t , interCfg )
526+
527+ leafCfg := certConfig {
528+ commonName : leafName ,
529+ dnsNames : []string {leafName },
530+ isCA : false ,
531+ serial : 3 ,
532+ parentCert : intermediate ,
533+ parentKey : interKey ,
534+ }
535+ leaf , leafKey := createCertificate (t , leafCfg )
536+
537+ return root , []* x509.Certificate {leaf , intermediate }, leafKey
538+ }
0 commit comments