Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ type AuthInfo interface {
AuthType() string
}

// AuthorityValidator defines an interface for validating the authority used to
// override the `:authority` header. A struct implementing AuthInfo should also
// implement AuthorityValidator if the credentials need to support per-RPC
// authority overrides.
type AuthorityValidator interface {
ValidateAuthority(authority string) error
}

// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
// and the caller should not close rawConn.
var ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
Expand Down
365 changes: 365 additions & 0 deletions credentials/credentials_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package credentials_test

import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/stubserver"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)

func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
}
auths, ok := md[":authority"]
if !ok {
return nil, status.Error(codes.InvalidArgument, "no authority header")
}
if len(auths) != 1 {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
}
if auths[0] != expectedAuthority {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
}
return &testpb.Empty{}, nil
}

func checkUnavailableRPCError(t *testing.T, err error) {
t.Helper()
if err == nil {
t.Fatalf("EmptyCall() should fail")
}
s, ok := status.FromError(err)
if !ok {
t.Fatalf("unexpected error: %v", err)
}
if s.Code() != codes.Unavailable {
t.Fatalf("EmptyCall() = _, %v, want _, error code: %v", s.Code(), codes.Unavailable)
}
}

// Tests the grpc.CallAuthority option with TLS credentials. This test verifies
// that the provided authority is correctly propagated to the server when using TLS.
// It covers both positive and negative cases: correct authority and incorrect
// authority, expecting the RPC to fail with `UNAVAILABLE` status code error in
// the latter case.
func (s) TestAuthorityCallOptionsWithTLSCreds(t *testing.T) {
tests := []struct {
name string
expectedAuth string
expectRPCError bool
}{
{
name: "CorrectAuthorityWithTLS",
expectedAuth: "auth.test.example.com",
expectRPCError: false,
},
{
name: "IncorrectAuthorityWithTLS",
expectedAuth: "auth.example.com",
expectRPCError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
log.Fatalf("failed to load key pair: %s", err)
}
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, tt.expectedAuth)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm.. I'm not convinced about checking the authority header on the server handler, because when the validation fails, the RPC will not even make it to the server. Maybe, checking from the server handler is OK when you actually expect validation to succeed on the client and expect the RPC to reach the server.

You know the certs you are using for the server. So, you can specify an authority override on the client that you expect to work and one that you dont expect to work, because it will fail validation with the peer certificate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what you mean , when the authority is not correct , we want the RPC call to return UNAVAILABLE error in client and that is what we are checking. And when it passes the validation , we want it to correctly reach the server and check if the correct authority has reached. We do not expect to check authority on server even when it is wrong or fails validation?

},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth))
if tt.expectRPCError {
checkUnavailableRPCError(t, err)
} else if err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
})
}
}

func (s) TestTLSCredsWithNoAuthorityOverride(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test comment for this.

What scenario is this testing? If this test fails, what does it indicate about the authority override feature?

cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
log.Fatalf("failed to load key pair: %s", err)
}
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, "x.test.example.com")
},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

// Tests the scenario where grpc.CallAuthority option is used with insecure credentials.
// The test verifies that the CallAuthority option is correctly passed even when
// insecure credentials are used.
func (s) TestAuthorityCallOptionWithInsecureCreds(t *testing.T) {
const expectedAuthority = "test.server.name"

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, expectedAuthority)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(expectedAuthority)); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

// FakeCredsNoAuthValidator is a test credential that does not implement AuthorityValidator.
type FakeCredsNoAuthValidator struct {
}

// ClientHandshake performs the client-side handshake.
func (c *FakeCredsNoAuthValidator) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, TestAuthInfo{}, nil
}

// TestAuthInfo implements the AuthInfo interface.
type TestAuthInfo struct{}

// AuthType returns the authentication type.
func (TestAuthInfo) AuthType() string { return "test" }

// Clone creates a copy of FakeCredsNoAuthValidator.
func (c *FakeCredsNoAuthValidator) Clone() credentials.TransportCredentials {
return c
}

// Info provides protocol information.
func (c *FakeCredsNoAuthValidator) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

// OverrideServerName overrides the server name used for verification.
func (c *FakeCredsNoAuthValidator) OverrideServerName(serverName string) error {
return nil
}

// ServerHandshake performs the server-side handshake.
// Returns a test AuthInfo object to satisfy the interface requirements.
func (c *FakeCredsNoAuthValidator) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, TestAuthInfo{}, nil
}

// TestCallOptionWithNoAuthorityValidator tests the CallAuthority call option
// with custom credentials that do not implement AuthorityValidator and verifies
// that it fails with `UNAVAILABLE` status code.
func (s) TestCallOptionWithNoAuthorityValidator(t *testing.T) {
const expectedAuthority = "auth.test.example.com"

// Initialize a stub server with a basic handler.
ss := stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Failed to start stub server: %v", err)
}
defer ss.Stop()

// Create a gRPC client connection with FakeCredsNoAuthValidator.
clientConn, err := grpc.NewClient(ss.Address,
grpc.WithTransportCredentials(&FakeCredsNoAuthValidator{}))
if err != nil {
t.Fatalf("Failed to create gRPC client connection: %v", err)
}
defer clientConn.Close()

// Perform a test RPC with a specified call authority.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(clientConn).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(expectedAuthority))

// Verify that the RPC fails with an UNAVAILABLE error.
checkUnavailableRPCError(t, err)
}

// FakeCredsWithAuthValidator is a test credential that does not implement AuthorityValidator.
type FakeCredsWithAuthValidator struct {
}

// ClientHandshake performs the client-side handshake.
func (c *FakeCredsWithAuthValidator) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, FakeAuthInfo{}, nil
}

// TestAuthInfo implements the AuthInfo interface.
type FakeAuthInfo struct{}

// AuthType returns the authentication type.
func (FakeAuthInfo) AuthType() string { return "test" }

// AuthType returns the authentication type.
func (FakeAuthInfo) ValidateAuthority(authority string) error {
if authority == "auth.test.example.com" {
return nil
} else {
return fmt.Errorf("invalid authority")
}
}

// Clone creates a copy of FakeCredsWithAuthValidator.
func (c *FakeCredsWithAuthValidator) Clone() credentials.TransportCredentials {
return c
}

// Info provides protocol information.
func (c *FakeCredsWithAuthValidator) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

// OverrideServerName overrides the server name used for verification.
func (c *FakeCredsWithAuthValidator) OverrideServerName(serverName string) error {
return nil
}

// ServerHandshake performs the server-side handshake.
// Returns a test AuthInfo object to satisfy the interface requirements.
func (c *FakeCredsWithAuthValidator) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, FakeAuthInfo{}, nil
}

// TestCorrectAuthorityWithCustomCreds tests the CallAuthority call option
// with custom credentials that implement AuthorityValidator and verifies
// it with both correct and incorrect authority override.
func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
tests := []struct {
name string
expectedAuth string
expectRPCError bool
}{
{
name: "CorrectAuthorityWithFakeCreds",
expectedAuth: "auth.test.example.com",
expectRPCError: false,
},
{
name: "IncorrectAuthorityWithFakeCreds",
expectedAuth: "auth.example.com",
expectRPCError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ss := stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, tt.expectedAuth)
},
}
if err := ss.StartServer(); err != nil {
t.Fatalf("Failed to start stub server: %v", err)
}
defer ss.Stop()

// Create a gRPC client connection with FakeCredsWithAuthValidator.
clientConn, err := grpc.NewClient(ss.Address,
grpc.WithTransportCredentials(&FakeCredsWithAuthValidator{}))
if err != nil {
t.Fatalf("Failed to create gRPC client connection: %v", err)
}
defer clientConn.Close()

// Perform a test RPC with a specified call authority.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

_, err = testgrpc.NewTestServiceClient(clientConn).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth))
if tt.expectRPCError {
checkUnavailableRPCError(t, err)
} else if err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
})
}
}
4 changes: 4 additions & 0 deletions credentials/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (info) AuthType() string {
return "insecure"
}

func (info) ValidateAuthority(_ string) error {
return nil
}

// insecureBundle implements an insecure bundle.
// An insecure bundle provides a thin wrapper around insecureTC to support
// the credentials.Bundle interface.
Expand Down
Loading