@@ -23,6 +23,7 @@ import (
2323 "crypto/tls"
2424 "crypto/x509"
2525 "fmt"
26+ "net"
2627 "os"
2728 "strings"
2829 "testing"
@@ -31,6 +32,7 @@ import (
3132 "google.golang.org/grpc"
3233 "google.golang.org/grpc/codes"
3334 "google.golang.org/grpc/credentials"
35+ "google.golang.org/grpc/internal/envconfig"
3436 "google.golang.org/grpc/internal/grpctest"
3537 "google.golang.org/grpc/internal/stubserver"
3638 "google.golang.org/grpc/status"
@@ -236,3 +238,160 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
236238 t .Fatalf ("EmptyCall err = %v; want <nil>" , err )
237239 }
238240}
241+
242+ // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
243+ // connecting to a server that doesn't support ALPN.
244+ func (s ) TestTLS_DisabledALPNClient (t * testing.T ) {
245+ initialVal := envconfig .EnforceALPNEnabled
246+ defer func () {
247+ envconfig .EnforceALPNEnabled = initialVal
248+ }()
249+
250+ tests := []struct {
251+ name string
252+ alpnEnforced bool
253+ wantErr bool
254+ }{
255+ {
256+ name : "enforced" ,
257+ alpnEnforced : true ,
258+ wantErr : true ,
259+ },
260+ {
261+ name : "not_enforced" ,
262+ },
263+ }
264+
265+ for _ , tc := range tests {
266+ t .Run (tc .name , func (t * testing.T ) {
267+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
268+
269+ listener , err := tls .Listen ("tcp" , "localhost:0" , & tls.Config {
270+ Certificates : []tls.Certificate {serverCert },
271+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
272+ })
273+ if err != nil {
274+ t .Fatalf ("Error starting TLS server: %v" , err )
275+ }
276+
277+ errCh := make (chan error , 1 )
278+ go func () {
279+ conn , err := listener .Accept ()
280+ if err != nil {
281+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
282+ } else {
283+ // The first write to the TLS listener initiates the TLS handshake.
284+ conn .Write ([]byte ("Hello, World!" ))
285+ conn .Close ()
286+ }
287+ close (errCh )
288+ }()
289+
290+ serverAddr := listener .Addr ().String ()
291+ conn , err := net .Dial ("tcp" , serverAddr )
292+ if err != nil {
293+ t .Fatalf ("net.Dial(%s) failed: %v" , serverAddr , err )
294+ }
295+ defer conn .Close ()
296+
297+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
298+ defer cancel ()
299+
300+ clientCfg := tls.Config {
301+ ServerName : serverName ,
302+ RootCAs : certPool ,
303+ NextProtos : []string {"h2" },
304+ }
305+ _ , _ , err = credentials .NewTLS (& clientCfg ).ClientHandshake (ctx , serverName , conn )
306+
307+ if gotErr := (err != nil ); gotErr != tc .wantErr {
308+ t .Errorf ("ClientHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
309+ }
310+
311+ select {
312+ case err := <- errCh :
313+ if err != nil {
314+ t .Fatalf ("Unexpected error received from server: %v" , err )
315+ }
316+ case <- ctx .Done ():
317+ t .Fatalf ("Timeout waiting for error from server" )
318+ }
319+ })
320+ }
321+ }
322+
323+ // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
324+ // accepting a request from a client that doesn't support ALPN.
325+ func (s ) TestTLS_DisabledALPNServer (t * testing.T ) {
326+ initialVal := envconfig .EnforceALPNEnabled
327+ defer func () {
328+ envconfig .EnforceALPNEnabled = initialVal
329+ }()
330+
331+ tests := []struct {
332+ name string
333+ alpnEnforced bool
334+ wantErr bool
335+ }{
336+ {
337+ name : "enforced" ,
338+ alpnEnforced : true ,
339+ wantErr : true ,
340+ },
341+ {
342+ name : "not_enforced" ,
343+ },
344+ }
345+
346+ for _ , tc := range tests {
347+ t .Run (tc .name , func (t * testing.T ) {
348+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
349+
350+ listener , err := net .Listen ("tcp" , "localhost:0" )
351+ if err != nil {
352+ t .Fatalf ("Error starting server: %v" , err )
353+ }
354+
355+ errCh := make (chan error , 1 )
356+ go func () {
357+ conn , err := listener .Accept ()
358+ if err != nil {
359+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
360+ return
361+ }
362+ defer conn .Close ()
363+ serverCfg := tls.Config {
364+ Certificates : []tls.Certificate {serverCert },
365+ NextProtos : []string {"h2" },
366+ }
367+ _ , _ , err = credentials .NewTLS (& serverCfg ).ServerHandshake (conn )
368+ if gotErr := (err != nil ); gotErr != tc .wantErr {
369+ t .Errorf ("ServerHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
370+ }
371+ close (errCh )
372+ }()
373+
374+ serverAddr := listener .Addr ().String ()
375+ clientCfg := & tls.Config {
376+ Certificates : []tls.Certificate {serverCert },
377+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
378+ RootCAs : certPool ,
379+ ServerName : serverName ,
380+ }
381+ conn , err := tls .Dial ("tcp" , serverAddr , clientCfg )
382+ if err != nil {
383+ t .Fatalf ("tls.Dial(%s) failed: %v" , serverAddr , err )
384+ }
385+ defer conn .Close ()
386+
387+ select {
388+ case <- time .After (defaultTestTimeout ):
389+ t .Fatal ("Timed out waiting for completion" )
390+ case err := <- errCh :
391+ if err != nil {
392+ t .Fatalf ("Unexpected server error: %v" , err )
393+ }
394+ }
395+ })
396+ }
397+ }
0 commit comments