Skip to content

Commit 59c0aec

Browse files
authored
xDS: Atomically read and write xDS security configuration client side (#6796)
1 parent ce3b538 commit 59c0aec

File tree

9 files changed

+76
-111
lines changed

9 files changed

+76
-111
lines changed

credentials/xds/xds.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"errors"
2828
"fmt"
2929
"net"
30+
"sync/atomic"
3031
"time"
3132

3233
"google.golang.org/grpc/credentials"
@@ -114,7 +115,9 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
114115
if chi.Attributes == nil {
115116
return c.fallback.ClientHandshake(ctx, authority, rawConn)
116117
}
117-
hi := xdsinternal.GetHandshakeInfo(chi.Attributes)
118+
119+
uPtr := xdsinternal.GetHandshakeInfo(chi.Attributes)
120+
hi := (*xdsinternal.HandshakeInfo)(atomic.LoadPointer(uPtr))
118121
if hi.UseFallbackCreds() {
119122
return c.fallback.ClientHandshake(ctx, authority, rawConn)
120123
}

credentials/xds/xds_client_test.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import (
2727
"net"
2828
"os"
2929
"strings"
30+
"sync/atomic"
3031
"testing"
3132
"time"
33+
"unsafe"
3234

3335
"google.golang.org/grpc/credentials"
3436
"google.golang.org/grpc/credentials/tls/certprovider"
@@ -219,11 +221,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
219221
// Creating the HandshakeInfo and adding it to the attributes is very
220222
// similar to what the CDS balancer would do when it intercepts calls to
221223
// NewSubConn().
222-
info := xdsinternal.NewHandshakeInfo(root, identity)
224+
var sms []matcher.StringMatcher
223225
if sanExactMatch != "" {
224-
info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
226+
sms = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
225227
}
226-
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
228+
info := xdsinternal.NewHandshakeInfo(root, identity, sms, false)
229+
uPtr := unsafe.Pointer(info)
230+
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
227231

228232
// Moving the attributes from the resolver.Address to the context passed to
229233
// the handshaker is done in the transport layer. Since we directly call the
@@ -533,13 +537,12 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
533537
// Create a root provider which will fail the handshake because it does not
534538
// use the correct trust roots.
535539
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
536-
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
537-
handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})
538-
540+
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
539541
// We need to repeat most of what newTestContextWithHandshakeInfo() does
540542
// here because we need access to the underlying HandshakeInfo so that we
541543
// can update it before the next call to ClientHandshake().
542-
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
544+
uPtr := unsafe.Pointer(handshakeInfo)
545+
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
543546
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
544547
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
545548
t.Fatal("ClientHandshake() succeeded when expected to fail")
@@ -560,7 +563,10 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
560563
// Create a new root provider which uses the correct trust roots. And update
561564
// the HandshakeInfo with the new provider.
562565
root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
563-
handshakeInfo.SetRootCertProvider(root2)
566+
handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
567+
// Update the existing pointer, which address attribute will continue to
568+
// point to.
569+
atomic.StorePointer(&uPtr, unsafe.Pointer(handshakeInfo))
564570
_, ai, err := creds.ClientHandshake(ctx, authority, conn)
565571
if err != nil {
566572
t.Fatalf("ClientHandshake() returned failed: %q", err)

credentials/xds/xds_server_test.go

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) {
122122
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
123123
}
124124

125-
info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil)
125+
info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false)
126126
conn := newWrappedConn(nil, info, time.Time{})
127127
if _, _, err := creds.ServerHandshake(conn); err == nil {
128128
t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo")
@@ -158,7 +158,7 @@ func (s) TestServerCredsProviderFailure(t *testing.T) {
158158
}
159159
for _, test := range tests {
160160
t.Run(test.desc, func(t *testing.T) {
161-
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
161+
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false)
162162
conn := newWrappedConn(nil, info, time.Time{})
163163
if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
164164
t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
@@ -232,8 +232,7 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
232232
// Create a test server which uses the xDS server credentials created above
233233
// to perform TLS handshake on incoming connections.
234234
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
235-
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"))
236-
hi.SetRequireClientCert(true)
235+
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true)
237236

238237
// Create a wrapped conn which can return the HandshakeInfo created
239238
// above with a very small deadline.
@@ -285,8 +284,7 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) {
285284
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
286285
// Create a HandshakeInfo which has a root provider which does not match
287286
// the certificate sent by the client.
288-
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
289-
hi.SetRequireClientCert(true)
287+
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
290288

291289
// Create a wrapped conn which can return the HandshakeInfo and
292290
// configured deadline to the xDS credentials' ServerHandshake()
@@ -367,8 +365,7 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
367365
// created above to perform TLS handshake on incoming connections.
368366
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
369367
// Create a HandshakeInfo with information from the test table.
370-
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
371-
hi.SetRequireClientCert(test.requireClientCert)
368+
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert)
372369

373370
// Create a wrapped conn which can return the HandshakeInfo and
374371
// configured deadline to the xDS credentials' ServerHandshake()
@@ -448,8 +445,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
448445
if cnt == 1 {
449446
// Create a HandshakeInfo which has a root provider which does not match
450447
// the certificate sent by the client.
451-
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
452-
hi.SetRequireClientCert(true)
448+
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
453449

454450
// Create a wrapped conn which can return the HandshakeInfo and
455451
// configured deadline to the xDS credentials' ServerHandshake()
@@ -463,8 +459,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
463459
return handshakeResult{}
464460
}
465461

466-
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
467-
hi.SetRequireClientCert(true)
462+
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true)
468463

469464
// Create a wrapped conn which can return the HandshakeInfo and
470465
// configured deadline to the xDS credentials' ServerHandshake()

internal/credentials/xds/handshake_info.go

Lines changed: 16 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
"errors"
2727
"fmt"
2828
"strings"
29-
"sync"
29+
"unsafe"
3030

3131
"google.golang.org/grpc/attributes"
3232
"google.golang.org/grpc/credentials/tls/certprovider"
@@ -66,59 +66,38 @@ func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool {
6666
}
6767

6868
// SetHandshakeInfo returns a copy of addr in which the Attributes field is
69-
// updated with hInfo.
70-
func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address {
71-
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hInfo)
69+
// updated with hiPtr.
70+
func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address {
71+
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr)
7272
return addr
7373
}
7474

75-
// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr.
76-
func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
75+
// GetHandshakeInfo returns a pointer to the *HandshakeInfo stored in attr.
76+
func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer {
7777
v := attr.Value(handshakeAttrKey{})
78-
hi, _ := v.(*HandshakeInfo)
78+
hi, _ := v.(*unsafe.Pointer)
7979
return hi
8080
}
8181

8282
// HandshakeInfo wraps all the security configuration required by client and
8383
// server handshake methods in xds credentials. The xDS implementation will be
8484
// responsible for populating these fields.
85-
//
86-
// Safe for concurrent access.
8785
type HandshakeInfo struct {
88-
mu sync.Mutex
86+
// All fields written at init time and read only after that, so no
87+
// synchronization needed.
8988
rootProvider certprovider.Provider
9089
identityProvider certprovider.Provider
9190
sanMatchers []matcher.StringMatcher // Only on the client side.
9291
requireClientCert bool // Only on server side.
9392
}
9493

95-
// SetRootCertProvider updates the root certificate provider.
96-
func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) {
97-
hi.mu.Lock()
98-
hi.rootProvider = root
99-
hi.mu.Unlock()
100-
}
101-
102-
// SetIdentityCertProvider updates the identity certificate provider.
103-
func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) {
104-
hi.mu.Lock()
105-
hi.identityProvider = identity
106-
hi.mu.Unlock()
107-
}
108-
109-
// SetSANMatchers updates the list of SAN matchers.
110-
func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) {
111-
hi.mu.Lock()
112-
hi.sanMatchers = sanMatchers
113-
hi.mu.Unlock()
114-
}
115-
116-
// SetRequireClientCert updates whether a client cert is required during the
117-
// ServerHandshake(). A value of true indicates that we are performing mTLS.
118-
func (hi *HandshakeInfo) SetRequireClientCert(require bool) {
119-
hi.mu.Lock()
120-
hi.requireClientCert = require
121-
hi.mu.Unlock()
94+
func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo {
95+
return &HandshakeInfo{
96+
rootProvider: rootProvider,
97+
identityProvider: identityProvider,
98+
sanMatchers: sanMatchers,
99+
requireClientCert: requireClientCert,
100+
}
122101
}
123102

124103
// UseFallbackCreds returns true when fallback credentials are to be used based
@@ -127,24 +106,18 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool {
127106
if hi == nil {
128107
return true
129108
}
130-
131-
hi.mu.Lock()
132-
defer hi.mu.Unlock()
133109
return hi.identityProvider == nil && hi.rootProvider == nil
134110
}
135111

136112
// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo.
137113
// To be used only for testing purposes.
138114
func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher {
139-
hi.mu.Lock()
140-
defer hi.mu.Unlock()
141115
return append([]matcher.StringMatcher{}, hi.sanMatchers...)
142116
}
143117

144118
// ClientSideTLSConfig constructs a tls.Config to be used in a client-side
145119
// handshake based on the contents of the HandshakeInfo.
146120
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
147-
hi.mu.Lock()
148121
// On the client side, rootProvider is mandatory. IdentityProvider is
149122
// optional based on whether the client is doing TLS or mTLS.
150123
if hi.rootProvider == nil {
@@ -153,7 +126,6 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config,
153126
// Since the call to KeyMaterial() can block, we read the providers under
154127
// the lock but call the actual function after releasing the lock.
155128
rootProv, idProv := hi.rootProvider, hi.identityProvider
156-
hi.mu.Unlock()
157129

158130
// InsecureSkipVerify needs to be set to true because we need to perform
159131
// custom verification to check the SAN on the received certificate.
@@ -188,7 +160,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
188160
ClientAuth: tls.NoClientCert,
189161
NextProtos: []string{"h2"},
190162
}
191-
hi.mu.Lock()
192163
// On the server side, identityProvider is mandatory. RootProvider is
193164
// optional based on whether the server is doing TLS or mTLS.
194165
if hi.identityProvider == nil {
@@ -200,7 +171,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
200171
if hi.requireClientCert {
201172
cfg.ClientAuth = tls.RequireAndVerifyClientCert
202173
}
203-
hi.mu.Unlock()
204174

205175
// identityProvider is mandatory on the server side.
206176
km, err := idProv.KeyMaterial(ctx)
@@ -225,8 +195,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
225195
// If the list of SAN matchers in the HandshakeInfo is empty, this function
226196
// returns true for all input certificates.
227197
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool {
228-
hi.mu.Lock()
229-
defer hi.mu.Unlock()
230198
if len(hi.sanMatchers) == 0 {
231199
return true
232200
}
@@ -325,9 +293,3 @@ func dnsMatch(host, san string) bool {
325293
hostPrefix := strings.TrimSuffix(host, san[1:])
326294
return !strings.Contains(hostPrefix, ".")
327295
}
328-
329-
// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root
330-
// and identity certificate providers.
331-
func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo {
332-
return &HandshakeInfo{rootProvider: root, identityProvider: identity}
333-
}

internal/credentials/xds/handshake_info_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ func TestMatchingSANExists_FailureCases(t *testing.T) {
188188

189189
for _, test := range tests {
190190
t.Run(test.desc, func(t *testing.T) {
191-
hi := NewHandshakeInfo(nil, nil)
192-
hi.SetSANMatchers(test.sanMatchers)
191+
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)
193192

194193
if hi.MatchingSANExists(inputCert) {
195194
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers)
@@ -289,8 +288,7 @@ func TestMatchingSANExists_Success(t *testing.T) {
289288

290289
for _, test := range tests {
291290
t.Run(test.desc, func(t *testing.T) {
292-
hi := NewHandshakeInfo(nil, nil)
293-
hi.SetSANMatchers(test.sanMatchers)
291+
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)
294292

295293
if !hi.MatchingSANExists(inputCert) {
296294
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers)

internal/internal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ var (
5757
// GetXDSHandshakeInfoForTesting returns a pointer to the xds.HandshakeInfo
5858
// stored in the passed in attributes. This is set by
5959
// credentials/xds/xds.go.
60-
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *xds.HandshakeInfo
60+
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *unsafe.Pointer
6161
// GetServerCredentials returns the transport credentials configured on a
6262
// gRPC server. An xDS-enabled server needs to know what type of credentials
6363
// is configured on the underlying gRPC server. This is set by server.go.

0 commit comments

Comments
 (0)