Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ type AuthInfo interface {
AuthType() string
}

// AuthorityValidator validates the authority used to override the `:authority`
// header. This is an optional interface that implementations of AuthInfo can
// implement if the credentials support per-RPC authority overrides. It is
// invoked when the application attempts to override the HTTP/2 `:authority`
// header using the CallAuthority call option.
type AuthorityValidator interface {
// ValidateAuthority checks the authority value used to override the
// `:authority` header. The authority parameter is the override value
// provided by the application via the CallAuthority option. This value
// typically corresponds to the server hostname or endpoint the RPC is
// targeting. It returns non-nil error if the validation fails.
ValidateAuthority(authority string) error
}

// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
// and the caller should not close rawConn.
var ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
Expand Down
333 changes: 333 additions & 0 deletions credentials/credentials_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package credentials_test

import (
"context"
"crypto/tls"
"fmt"
"net"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

func authorityChecker(ctx context.Context, wantAuthority string) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.InvalidArgument, "failed to parse metadata")
}
auths, ok := md[":authority"]
if !ok {
return status.Error(codes.InvalidArgument, "no authority header")
}
if len(auths) != 1 {
return status.Errorf(codes.InvalidArgument, "expected exactly one authority header, got %v", auths)
}
if auths[0] != wantAuthority {
return status.Errorf(codes.InvalidArgument, "invalid authority header %q, want %q", auths[0], wantAuthority)
}
return nil
}

// Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies
// that the provided authority is correctly propagated to the server when a
// correct authority is used.
func (s) TestCorrectAuthorityWithTLSCreds(t *testing.T) {
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
t.Fatalf("Failed to load key pair: %s", err)
}
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
authority := "auth.test.example.com"
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
if err := authorityChecker(ctx, authority); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK)
}

}

// Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies
// that the RPC fails with `UNAVAILABLE` status code and doesn't reach the server
// when an incorrect authority is used.
func (s) TestIncorrectAuthorityWithTLS(t *testing.T) {
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
t.Fatalf("Failed to load key pair: %s", err)
}
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
authority := "auth.example.com"
serverCalled := make(chan struct{})
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
close(serverCalled)
return nil, nil
},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.Unavailable {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable)
}
select {
case <-serverCalled:
t.Fatalf("Server handler should not have been called")
case <-time.After(defaultTestShortTimeout):
}
}

// Tests the scenario where the `grpc.CallAuthority` call option is used with
// insecure transport credentials. The test verifies that the specified
// authority is correctly propagated to the server.
func (s) TestAuthorityCallOptionWithInsecureCreds(t *testing.T) {
authority := "test.server.name"

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
if err := authorityChecker(ctx, authority); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

// testAuthInfoNoValidator implements only credentials.AuthInfo and not
// credentials.AuthorityValidator.
type testAuthInfoNoValidator struct{}

// AuthType returns the authentication type.
func (testAuthInfoNoValidator) AuthType() string {
return "test"
}

// testAuthInfoWithValidator implements both credentials.AuthInfo and
// credentials.AuthorityValidator.
type testAuthInfoWithValidator struct {
validAuthority string
}

// AuthType returns the authentication type.
func (testAuthInfoWithValidator) AuthType() string {
return "test"
}

// ValidateAuthority implements credentials.AuthorityValidator.
func (v testAuthInfoWithValidator) ValidateAuthority(authority string) error {
if authority == v.validAuthority {
return nil
}
return fmt.Errorf("invalid authority")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: capture the invalid authority value in the returned error

}

// testCreds is a test TransportCredentials that can optionally support
// authority validation.
type testCreds struct {
WithValidator bool
Authority string
}

// ClientHandshake performs the client-side handshake.
func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if c.WithValidator {
return rawConn, testAuthInfoWithValidator{validAuthority: c.Authority}, nil
}
return rawConn, testAuthInfoNoValidator{}, nil
}

// ServerHandshake performs the server-side handshake.
func (c *testCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if c.WithValidator {
return rawConn, testAuthInfoWithValidator{validAuthority: c.Authority}, nil
}
return rawConn, testAuthInfoNoValidator{}, nil
}

// Clone creates a copy of testCreds.
func (c *testCreds) Clone() credentials.TransportCredentials {
return &testCreds{WithValidator: c.WithValidator}
}

// Info provides protocol information.
func (c *testCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

// OverrideServerName overrides the server name used for verification.
func (c *testCreds) OverrideServerName(serverName string) error {
return nil
}

// TestAuthorityValidationFailureWithCustomCreds tests the `grpc.CallAuthority` call
// option using custom credentials. It verifies behavior both, when the
// credentials implement AuthorityValidator with incorrect authority override,
// as well as when the credentials do not implement AuthorityValidator. Both the
// cases are expected to fail with `UNAVAILABLE` status code.
func (s) TestAuthorityValidationFailureWithCustomCreds(t *testing.T) {
tests := []struct {
name string
creds credentials.TransportCredentials
authority string
wantStatus codes.Code
}{
{
name: "IncorrectAuthorityWithFakeCreds",
authority: "auth.example.com",
creds: &testCreds{WithValidator: true, Authority: "auth.test.example.com"},
wantStatus: codes.Unavailable,
},
{
name: "FakeCredsWithNoAuthValidator",
creds: &testCreds{WithValidator: false},
authority: "auth.test.example.com",
wantStatus: codes.Unavailable,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
serverCalled := make(chan struct{})
ss := stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
close(serverCalled)
return nil, nil
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Failed to start stub server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(tt.creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.authority)); status.Code(err) != tt.wantStatus {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), tt.wantStatus)
}
select {
case <-serverCalled:
t.Fatalf("Server should not have been called")
case <-time.After(defaultTestShortTimeout):
}
})
}

}

// TestCorrectAuthorityWithCustomCreds tests the `grpc.CallAuthority` call
// option using custom credentials. It verifies that the provided authority is
// correctly propagated to the server when a correct authority is used.
func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
authority := "auth.test.example.com"
creds := &testCreds{WithValidator: true, Authority: "auth.test.example.com"}
ss := stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
if err := authorityChecker(ctx, authority); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Failed to start stub server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK)
}
}
6 changes: 6 additions & 0 deletions credentials/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ func (info) AuthType() string {
return "insecure"
}

// ValidateAuthority allows any value to be overridden for the :authority
// header.
func (info) ValidateAuthority(string) error {
return nil
}

// insecureBundle implements an insecure bundle.
// An insecure bundle provides a thin wrapper around insecureTC to support
// the credentials.Bundle interface.
Expand Down
16 changes: 16 additions & 0 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -50,6 +51,21 @@ func (t TLSInfo) AuthType() string {
return "tls"
}

// ValidateAuthority validates the provided authority being used to override the
// :authority header by verifying it against the peer certificates. It returns a
// non-nil error if the validation fails.
func (t TLSInfo) ValidateAuthority(authority string) error {
var errs []error
for _, cert := range t.State.PeerCertificates {
var err error
if err = cert.VerifyHostname(authority); err == nil {
return nil
}
errs = append(errs, err)
}
return fmt.Errorf("credentials: invalid authority %q: %v", authority, errors.Join(errs...))
}

// cipherSuiteLookup returns the string version of a TLS cipher suite ID.
func cipherSuiteLookup(cipherSuiteID uint16) string {
for _, s := range tls.CipherSuites() {
Expand Down
Loading