Skip to content

Commit c4d55e0

Browse files
Unit test for HandShakeInfo.Equal
1 parent 9483fdb commit c4d55e0

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

internal/credentials/xds/handshake_info_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,27 @@
1919
package xds
2020

2121
import (
22+
"context"
2223
"crypto/x509"
2324
"net"
2425
"net/url"
2526
"regexp"
2627
"testing"
2728

29+
"google.golang.org/grpc/credentials/tls/certprovider"
2830
"google.golang.org/grpc/internal/xds/matcher"
2931
)
3032

33+
type mockCertProvider struct {
34+
id int
35+
}
36+
37+
func (d *mockCertProvider) KeyMaterial(_ context.Context) (*certprovider.KeyMaterial, error) {
38+
return &certprovider.KeyMaterial{}, nil
39+
}
40+
41+
func (d *mockCertProvider) Close() {}
42+
3143
func TestDNSMatch(t *testing.T) {
3244
tests := []struct {
3345
desc string
@@ -300,3 +312,102 @@ func TestMatchingSANExists_Success(t *testing.T) {
300312
func newStringP(s string) *string {
301313
return &s
302314
}
315+
316+
func TestEqual(t *testing.T) {
317+
mockProvider1 := &mockCertProvider{id: 1}
318+
mockProvider2 := &mockCertProvider{id: 2}
319+
320+
tests := []struct {
321+
desc string
322+
hi1 *HandshakeInfo
323+
hi2 *HandshakeInfo
324+
wantMatch bool
325+
}{
326+
{
327+
desc: "both HandshakeInfo are nil",
328+
hi1: nil,
329+
hi2: nil,
330+
wantMatch: true,
331+
},
332+
{
333+
desc: "one HandshakeInfo is nil",
334+
hi1: nil,
335+
hi2: NewHandshakeInfo(mockProvider1, nil, nil, false),
336+
wantMatch: false,
337+
},
338+
{
339+
desc: "different root providers",
340+
hi1: NewHandshakeInfo(mockProvider1, nil, nil, false),
341+
hi2: NewHandshakeInfo(mockProvider2, nil, nil, false),
342+
wantMatch: false,
343+
},
344+
{
345+
desc: "different identity providers",
346+
hi1: NewHandshakeInfo(nil, mockProvider1, nil, false),
347+
hi2: NewHandshakeInfo(nil, mockProvider2, nil, false),
348+
wantMatch: false,
349+
},
350+
{
351+
desc: "same providers, same SAN matchers",
352+
hi1: NewHandshakeInfo(mockProvider1, mockProvider1, []matcher.StringMatcher{
353+
matcher.StringMatcherForTesting(newStringP("foo.com"), nil, nil, nil, nil, false),
354+
}, false),
355+
hi2: NewHandshakeInfo(mockProvider1, mockProvider1, []matcher.StringMatcher{
356+
matcher.StringMatcherForTesting(newStringP("foo.com"), nil, nil, nil, nil, false),
357+
}, false),
358+
wantMatch: true,
359+
},
360+
{
361+
desc: "same providers, different SAN matchers",
362+
hi1: NewHandshakeInfo(mockProvider1, mockProvider1, []matcher.StringMatcher{
363+
matcher.StringMatcherForTesting(newStringP("foo.com"), nil, nil, nil, nil, false),
364+
}, false),
365+
hi2: NewHandshakeInfo(mockProvider1, mockProvider1, []matcher.StringMatcher{
366+
matcher.StringMatcherForTesting(newStringP("bar.com"), nil, nil, nil, nil, false),
367+
}, false),
368+
wantMatch: false,
369+
},
370+
{
371+
desc: "same SAN matchers with different content",
372+
hi1: NewHandshakeInfo(mockProvider1, mockProvider1, []matcher.StringMatcher{
373+
matcher.StringMatcherForTesting(newStringP("foo.com"), nil, nil, nil, nil, false),
374+
}, false),
375+
hi2: NewHandshakeInfo(mockProvider1, mockProvider1, []matcher.StringMatcher{
376+
matcher.StringMatcherForTesting(newStringP("foo.com"), nil, nil, nil, nil, false),
377+
matcher.StringMatcherForTesting(newStringP("bar.com"), nil, nil, nil, nil, false),
378+
}, false),
379+
wantMatch: false,
380+
},
381+
{
382+
desc: "different requireClientCert flags",
383+
hi1: NewHandshakeInfo(mockProvider1, mockProvider1, nil, true),
384+
hi2: NewHandshakeInfo(mockProvider1, mockProvider1, nil, false),
385+
wantMatch: false,
386+
},
387+
{
388+
desc: "same rootProvider but different mockCertProvider state",
389+
hi1: &HandshakeInfo{
390+
rootProvider: mockProvider1,
391+
identityProvider: mockProvider1,
392+
sanMatchers: nil,
393+
requireClientCert: false,
394+
},
395+
hi2: &HandshakeInfo{
396+
rootProvider: &mockCertProvider{id: 1},
397+
identityProvider: mockProvider1,
398+
sanMatchers: nil,
399+
requireClientCert: false,
400+
},
401+
wantMatch: false,
402+
},
403+
}
404+
405+
for _, test := range tests {
406+
t.Run(test.desc, func(t *testing.T) {
407+
gotMatch := test.hi1.Equal(test.hi2)
408+
if gotMatch != test.wantMatch {
409+
t.Errorf("hi1.Equal(hi2) = %v; wantMatch %v", gotMatch, test.wantMatch)
410+
}
411+
})
412+
}
413+
}

0 commit comments

Comments
 (0)