Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions internal/xds/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ package bootstrap

import (
"bytes"
"context"
"encoding/json"
"fmt"
"maps"
"net"
"net/url"
"os"
"slices"
Expand Down Expand Up @@ -179,6 +181,7 @@ type ServerConfig struct {
// credentials and store it here for easy access.
selectedCreds ChannelCreds
credsDialOption grpc.DialOption
dialerOption grpc.DialOption

cleanups []func()
}
Expand Down Expand Up @@ -223,6 +226,12 @@ func (sc *ServerConfig) CredsDialOption() grpc.DialOption {
return sc.credsDialOption
}

// DialerOption returns the first supported Dialer function that specifies how
// to dial the xDS server from the configuration, as a dial option.
func (sc *ServerConfig) DialerOption() grpc.DialOption {
return sc.dialerOption
}

// Cleanups returns a collection of functions to be called when the xDS client
// for this server is closed. Allows cleaning up resources created specifically
// for this server.
Expand Down Expand Up @@ -275,6 +284,12 @@ func (sc *ServerConfig) MarshalJSON() ([]byte, error) {
return json.Marshal(server)
}

// dialer captures the Dialer method specified via the credentials bundle.
type dialer interface {
// Dialer specifies how to dial the xDS server.
Dialer(context.Context, string) (net.Conn, error)
}

// UnmarshalJSON takes the json data (a server) and unmarshals it to the struct.
func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
server := serverConfigJSON{}
Expand All @@ -298,6 +313,10 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
}
sc.selectedCreds = cc
sc.credsDialOption = grpc.WithCredentialsBundle(bundle)
d, ok := bundle.(dialer)
if ok {
sc.dialerOption = grpc.WithContextDialer(d.Dialer)
}
sc.cleanups = append(sc.cleanups, cancel)
break
}
Expand Down
4 changes: 4 additions & 0 deletions xds/internal/xdsclient/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ func New(opts Options) (*Transport, error) {
Timeout: 20 * time.Second,
}),
}
dialerOpts := opts.ServerCfg.DialerOption()
if dialerOpts != nil {
dopts = append(dopts, dialerOpts)
}
grpcNewClient := transportinternal.GRPCNewClient.(func(string, ...grpc.DialOption) (*grpc.ClientConn, error))
cc, err := grpcNewClient(opts.ServerCfg.ServerURI(), dopts...)
if err != nil {
Expand Down
83 changes: 81 additions & 2 deletions xds/internal/xdsclient/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,62 @@
package transport_test

import (
"context"
"encoding/json"
"net"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/internal/xds/bootstrap"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
internalbootstrap "google.golang.org/grpc/internal/xds/bootstrap"
"google.golang.org/grpc/xds/bootstrap"
"google.golang.org/grpc/xds/internal/xdsclient/transport"
"google.golang.org/grpc/xds/internal/xdsclient/transport/internal"

v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
)

const testDialerCredsBuilderName = "test_dialer_creds"

func init() {
bootstrap.RegisterCredentials(&testDialerCredsBuilder{})
}

// testDialerCredsBuilder implements the `Credentials` interface defined in
// package `xds/bootstrap` and encapsulates an insecure credential with a
// custom Dialer that specifies how to dial the xDS server.
type testDialerCredsBuilder struct{}

func (t *testDialerCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return &testDialerCredsBundle{}, func() {}, nil
}

func (t *testDialerCredsBuilder) Name() string {
return testDialerCredsBuilderName
}

// testDialerCredsBundle implements the `Bundle` interface defined in package
// `credentials` and encapsulates an insecure credential with a custom Dialer
// that specifies how to dial the xDS server.
type testDialerCredsBundle struct{}

func (t *testDialerCredsBundle) TransportCredentials() credentials.TransportCredentials {
return insecure.NewCredentials()
}

func (t *testDialerCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
return nil
}

func (t *testDialerCredsBundle) NewWithMode(string) (credentials.Bundle, error) {
return &testDialerCredsBundle{}, nil
}

func (t *testDialerCredsBundle) Dialer(context.Context, string) (net.Conn, error) {
return nil, nil
}

func (s) TestNewWithGRPCDial(t *testing.T) {
// Override the dialer with a custom one.
customDialerCalled := false
Expand All @@ -39,7 +85,7 @@ func (s) TestNewWithGRPCDial(t *testing.T) {
internal.GRPCNewClient = customDialer
defer func() { internal.GRPCNewClient = oldDial }()

serverCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: "server-address"})
serverCfg, err := internalbootstrap.ServerConfigForTesting(internalbootstrap.ServerConfigTestingOptions{URI: "server-address"})
if err != nil {
t.Fatalf("Failed to create server config for testing: %v", err)
}
Expand Down Expand Up @@ -82,3 +128,36 @@ func (s) TestNewWithGRPCDial(t *testing.T) {
t.Fatalf("transport.New(%+v) custom dialer called = true, want false", opts)
}
}

func (s) TestNewWithDialerFromCredentialsBundle(t *testing.T) {
serverCfg, err := internalbootstrap.ServerConfigForTesting(internalbootstrap.ServerConfigTestingOptions{
URI: "trafficdirector.googleapis.com:443",
ChannelCreds: []internalbootstrap.ChannelCreds{{Type: testDialerCredsBuilderName}},
})
if err != nil {
t.Fatalf("Failed to create server config for testing: %v", err)
}
if serverCfg.DialerOption() == nil {
t.Fatalf("Dialer for xDS transport in server config for testing is nil, want non-nil")
}
// Create a new transport.
opts := transport.Options{
ServerCfg: serverCfg,
NodeProto: &v3corepb.Node{},
OnRecvHandler: func(update transport.ResourceUpdate, onDone func()) error {
onDone()
return nil
},
OnErrorHandler: func(error) {},
OnSendHandler: func(*transport.ResourceSendInfo) {},
}
c, err := transport.New(opts)
defer func() {
if c != nil {
c.Close()
}
}()
if err != nil {
t.Fatalf("transport.New(%v) failed: %v", opts, err)
}
}