Skip to content

Commit 24a6b48

Browse files
authored
credentials/alts: fix defer in TestDial (#7301)
1 parent e37c6e8 commit 24a6b48

File tree

2 files changed

+18
-23
lines changed

2 files changed

+18
-23
lines changed

credentials/alts/internal/handshaker/service/service.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ var (
3434
// to a corresponding connection to a hypervisor handshaker service
3535
// instance.
3636
hsConnMap = make(map[string]*grpc.ClientConn)
37-
// hsDialer will be reassigned in tests.
38-
hsDialer = grpc.Dial
3937
)
4038

4139
// Dial dials the handshake service in the hypervisor. If a connection has
@@ -50,7 +48,7 @@ func Dial(hsAddress string) (*grpc.ClientConn, error) {
5048
// Create a new connection to the handshaker service. Note that
5149
// this connection stays open until the application is closed.
5250
var err error
53-
hsConn, err = hsDialer(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
51+
hsConn, err = grpc.Dial(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
5452
if err != nil {
5553
return nil, err
5654
}

credentials/alts/internal/handshaker/service/service_test.go

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,33 @@ package service
2121
import (
2222
"testing"
2323

24-
grpc "google.golang.org/grpc"
24+
"google.golang.org/grpc/internal/grpctest"
2525
)
2626

27+
type s struct {
28+
grpctest.Tester
29+
}
30+
31+
func Test(t *testing.T) {
32+
grpctest.RunSubTests(t, s{})
33+
}
34+
2735
const (
2836
testAddress1 = "some_address_1"
2937
testAddress2 = "some_address_2"
3038
)
3139

32-
func TestDial(t *testing.T) {
33-
defer func() func() {
34-
temp := hsDialer
35-
hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
36-
return &grpc.ClientConn{}, nil
37-
}
38-
return func() {
39-
hsDialer = temp
40-
}
41-
}()
42-
40+
// TestDial verifies the behaviour of alts handshake when there are multiple Dials.
41+
// If a connection has already been established, this function returns it.
42+
// Otherwise, a new connection is created.
43+
func (s) TestDial(t *testing.T) {
4344
// First call to Dial, it should create a connection to the server running
4445
// at the given address.
4546
conn1, err := Dial(testAddress1)
4647
if err != nil {
4748
t.Fatalf("first call to Dial(%v) failed: %v", testAddress1, err)
4849
}
49-
if conn1 == nil {
50-
t.Fatalf("first call to Dial(%v)=(nil, _), want not nil", testAddress1)
51-
}
50+
defer conn1.Close()
5251
if got, want := hsConnMap[testAddress1], conn1; got != want {
5352
t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want)
5453
}
@@ -58,22 +57,20 @@ func TestDial(t *testing.T) {
5857
if err != nil {
5958
t.Fatalf("second call to Dial(%v) failed: %v", testAddress1, err)
6059
}
60+
defer conn2.Close()
6161
if got, want := conn2, conn1; got != want {
6262
t.Fatalf("second call to Dial(%v)=(%v, _), want (%v,. _)", testAddress1, got, want)
6363
}
6464
if got, want := hsConnMap[testAddress1], conn1; got != want {
6565
t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want)
6666
}
6767

68-
// Third call to Dial using a different address should create a new
69-
// connection.
68+
// Third call to Dial using a different address should create a new connection.
7069
conn3, err := Dial(testAddress2)
7170
if err != nil {
7271
t.Fatalf("third call to Dial(%v) failed: %v", testAddress2, err)
7372
}
74-
if conn3 == nil {
75-
t.Fatalf("third call to Dial(%v)=(nil, _), want not nil", testAddress2)
76-
}
73+
defer conn3.Close()
7774
if got, want := hsConnMap[testAddress2], conn3; got != want {
7875
t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress2, got, want)
7976
}

0 commit comments

Comments
 (0)