Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,21 @@ type AuthInfo interface {
AuthType() string
}

// AuthorityValidator validates the authority used to override the `:authority`
// header. This is an optional interface that implementations of AuthInfo can
// implement if the credentials support per-RPC authority overrides. It is
// invoked when the application attempts to override the HTTP/2 `:authority`
// header using the CallAuthority call option.
type AuthorityValidator interface {
// ValidateAuthority checks the authority value used to override the
// `:authority` header. The authority parameter is the override value
// provided by the application via the CallAuthority option. This value
// typically corresponds to the server hostname or endpoint the RPC is
// targeting. It returns nil if the validation succeeds, and a non-nil error
// if the validation fails.
ValidateAuthority(authority string) error
}

// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
// and the caller should not close rawConn.
var ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
Expand Down
275 changes: 275 additions & 0 deletions credentials/credentials_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package credentials_test

import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

var cert tls.Certificate
var creds credentials.TransportCredentials

func init() {
var err error
cert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
log.Fatalf("failed to load key pair: %s", err)
}
creds, err = credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
log.Fatalf("Failed to create credentials %v", err)
}
}

func authorityChecker(ctx context.Context, wantAuthority string) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
}
auths, ok := md[":authority"]
if !ok {
return nil, status.Error(codes.InvalidArgument, "no authority header")
}
if len(auths) != 1 {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
}
if auths[0] != wantAuthority {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, want %v", auths[0], wantAuthority))
}
return &testpb.Empty{}, nil
}

// Tests the grpc.CallAuthority option with TLS credentials. This test verifies
// that the provided authority is correctly propagated to the server when using
// TLS. It covers both positive and negative cases: correct authority and
// incorrect authority, expecting the RPC to fail with `UNAVAILABLE` status code
// error in the later case.
func TestAuthorityCallOptionsWithTLSCreds(t *testing.T) {
tests := []struct {
name string
wantAuth string
wantStatus codes.Code
}{
{
name: "CorrectAuthority",
wantAuth: "auth.test.example.com",
wantStatus: codes.OK,
},
{
name: "IncorrectAuthority",
wantAuth: "auth.example.com",
wantStatus: codes.Unavailable,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, tt.wantAuth)
},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.wantAuth)); status.Code(err) != tt.wantStatus {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), tt.wantStatus)
}
})
}
}

// Tests the scenario where the `grpc.CallAuthority` call option is used with
// insecure transport credentials. The test verifies that the specified
// authority is correctly propagated to the server.
func (s) TestAuthorityCallOptionWithInsecureCreds(t *testing.T) {
const wantAuthority = "test.server.name"

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, wantAuthority)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(wantAuthority)); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

// testAuthInfoNoValidator implements only credentials.AuthInfo and not
// credentials.AuthorityValidator.
type testAuthInfoNoValidator struct{}

// AuthType returns the authentication type.
func (testAuthInfoNoValidator) AuthType() string {
return "test"
}

// testAuthInfoWithValidator implements both credentials.AuthInfo and
// credentials.AuthorityValidator.
type testAuthInfoWithValidator struct{}

// AuthType returns the authentication type.
func (testAuthInfoWithValidator) AuthType() string {
return "test"
}

// ValidateAuthority implements credentials.AuthorityValidator.
func (testAuthInfoWithValidator) ValidateAuthority(authority string) error {
if authority == "auth.test.example.com" {
return nil
}
return fmt.Errorf("invalid authority")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: capture the invalid authority value in the returned error

}

// testCreds is a test TransportCredentials that can optionally support
// authority validation.
type testCreds struct {
WithValidator bool
}

// ClientHandshake performs the client-side handshake.
func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if c.WithValidator {
return rawConn, testAuthInfoWithValidator{}, nil
}
return rawConn, testAuthInfoNoValidator{}, nil
}

// ServerHandshake performs the server-side handshake.
func (c *testCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if c.WithValidator {
return rawConn, testAuthInfoWithValidator{}, nil
}
return rawConn, testAuthInfoNoValidator{}, nil
}

// Clone creates a copy of testCreds.
func (c *testCreds) Clone() credentials.TransportCredentials {
return &testCreds{WithValidator: c.WithValidator}
}

// Info provides protocol information.
func (c *testCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

// OverrideServerName overrides the server name used for verification.
func (c *testCreds) OverrideServerName(serverName string) error {
return nil
}

// TestCorrectAuthorityWithCustomCreds tests the `grpc.CallAuthority` call
// option using custom credentials. It verifies behavior both, when the
// credentials implement AuthorityValidator with both correct and incorrect
// authority overrides, as well as when the credentials do not implement
// AuthorityValidator. The later two cases, i.e when the credentials do not
// implement AuthorityValidator, and the authority used to override is invalid,
// are expected to fail with `UNAVAILABLE` status code.
func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
tests := []struct {
name string
creds credentials.TransportCredentials
wantAuth string
wantStatus codes.Code
}{
{
name: "CorrectAuthorityWithFakeCreds",
wantAuth: "auth.test.example.com",
creds: &testCreds{WithValidator: true},
wantStatus: codes.OK,
},
{
name: "IncorrectAuthorityWithFakeCreds",
wantAuth: "auth.example.com",
creds: &testCreds{WithValidator: true},
wantStatus: codes.Unavailable,
},
{
name: "FakeCredsWithNoAuthValidator",
creds: &testCreds{WithValidator: false},
wantAuth: "auth.test.example.com",
wantStatus: codes.Unavailable,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ss := stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, tt.wantAuth)
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Failed to start stub server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(tt.creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.wantAuth)); status.Code(err) != tt.wantStatus {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), tt.wantStatus)
}
})
}
}
4 changes: 4 additions & 0 deletions credentials/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (info) AuthType() string {
return "insecure"
}

func (info) ValidateAuthority(string) error {
return nil
}

// insecureBundle implements an insecure bundle.
// An insecure bundle provides a thin wrapper around insecureTC to support
// the credentials.Bundle interface.
Expand Down
17 changes: 17 additions & 0 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -50,6 +51,22 @@ func (t TLSInfo) AuthType() string {
return "tls"
}

// ValidateAuthority validates that the provided authority being used to
// override the :authority header is valid by verifying it against the peer
// certificates. It returns nil on successful validation, or a non-nil error if
// the validation fails.
func (t TLSInfo) ValidateAuthority(authority string) error {
var errs []error
for _, cert := range t.State.PeerCertificates {
var err error
if err = cert.VerifyHostname(authority); err == nil {
return nil
}
errs = append(errs, err)
}
return fmt.Errorf("credentials: invalid authority %q: %w", authority, errors.Join(errs...))
}

// cipherSuiteLookup returns the string version of a TLS cipher suite ID.
func cipherSuiteLookup(cipherSuiteID uint16) string {
for _, s := range tls.CipherSuites() {
Expand Down
19 changes: 19 additions & 0 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,25 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
callHdr = &newCallHdr
}

// The authority specified via the `CallAuthority` CallOption takes the
// highest precedence when determining the `:authority` header. It overrides
// any value present in the Host field of CallHdr. Before applying this
// override, the authority string is validated. If the credentials do not
// implement the AuthorityValidator interface, or if validation fails, the
// RPC is failed with a status code of `UNAVAILABLE`.
if callHdr.Authority != "" {
auth, ok := t.authInfo.(credentials.AuthorityValidator)
if !ok {
return nil, &NewStreamError{Err: status.Error(codes.Unavailable, fmt.Sprintf("credentials type %s does not implement the AuthorityValidator interface, but authority override specified with CallAuthority call option", t.authInfo.AuthType()))}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here too. Use status.Errorf.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And use %q formatting directive for the credentials type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

}
if err := auth.ValidateAuthority(callHdr.Authority); err != nil {
return nil, &NewStreamError{Err: status.Error(codes.Unavailable, fmt.Sprintf("failed to validate authority %s : %s", callHdr.Authority, err))}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the above comments here too. And use %v for err.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

}
newCallHdr := *callHdr
newCallHdr.Host = callHdr.Authority
callHdr = &newCallHdr
}

headerFields, err := t.createHeaderFields(ctx, callHdr)
if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
Expand Down
5 changes: 5 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ type CallHdr struct {
PreviousAttempts int // value of grpc-previous-rpc-attempts header to set

DoneFunc func() // called when the stream is finished

// Authority is used to explicitly override the `:authority` header. If set,
// this value takes precedence over the Host field and will be used as the
// value for the `:authority` header.
Authority string
}

// ClientTransport is the common interface for all gRPC client-side transport
Expand Down
Loading