Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 129 additions & 2 deletions chk/chk.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/openconfig/gribigo/client"
"github.com/openconfig/gribigo/fluent"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"

spb "google.golang.org/genproto/googleapis/rpc/status"
spb "github.com/openconfig/gribi/v1/proto/service"
gspb "google.golang.org/genproto/googleapis/rpc/status"
)

// resultOpt is an interface implemented by all options that can be
Expand Down Expand Up @@ -106,6 +108,50 @@ func HasResult(t testing.TB, res []*client.OpResult, want *client.OpResult, opt
}
}

// HasResultsCache implements an efficient mechanism to call HasResults across
// a large set of operations results. HasResultsCache checks whether each result
// in wants is present in res, using the options specified.
func HasResultsCache(t testing.TB, res, wants []*client.OpResult, opt ...resultOpt) {
t.Helper()

byOpID := map[uint64]*client.OpResult{}
byNHID := map[uint64]*client.OpResult{}
byNHGID := map[uint64]*client.OpResult{}
byIPv4Prefix := map[string]*client.OpResult{}

for _, r := range res {
byOpID[r.OperationID] = r
if r.Details != nil {
switch {
case r.Details.NextHopGroupID != 0:
byNHGID[r.Details.NextHopGroupID] = r
case r.Details.NextHopIndex != 0:
byNHID[r.Details.NextHopIndex] = r
case r.Details.IPv4Prefix != "":
byIPv4Prefix[r.Details.IPv4Prefix] = r
}
}
}

if !hasIgnoreOperationID(opt) {
for _, want := range wants {
HasResult(t, []*client.OpResult{byOpID[want.OperationID]}, want, opt...)
}
return
}

for _, want := range wants {
switch {
case want.Details.NextHopGroupID != 0:
HasResult(t, []*client.OpResult{byNHGID[want.Details.NextHopGroupID]}, want, opt...)
case want.Details.NextHopIndex != 0:
HasResult(t, []*client.OpResult{byNHID[want.Details.NextHopIndex]}, want, opt...)
case want.Details.IPv4Prefix != "":
HasResult(t, []*client.OpResult{byIPv4Prefix[want.Details.IPv4Prefix]}, want, opt...)
}
}
}

// clientError converts the given error into a client ClientErr.
func clientError(t testing.TB, err error) *client.ClientErr {
t.Helper()
Expand Down Expand Up @@ -155,13 +201,15 @@ func AllowUnimplemented() *allowUnimplemented {

// HasRecvClientErrorWithStatus checks whether the supplied ClientErr ce contains a status with
// the code and details set to the values supplied in want.
//
// TODO(robjs): Add unit test for this check.
func HasRecvClientErrorWithStatus(t testing.TB, err error, want *status.Status, opts ...ErrorOpt) {
t.Helper()

okMsgs := []*status.Status{want}
for _, o := range opts {
if _, ok := o.(*allowUnimplemented); ok {
uProto := proto.Clone(want.Proto()).(*spb.Status)
uProto := proto.Clone(want.Proto()).(*gspb.Status)
uProto.Code = int32(codes.Unimplemented)
unimpl := status.FromProto(uProto)
okMsgs = append(okMsgs, unimpl)
Expand All @@ -188,3 +236,82 @@ func HasRecvClientErrorWithStatus(t testing.TB, err error, want *status.Status,
t.Fatalf("client does not have receive error with status %s, got: %v", want.Proto(), ce.Recv)
}
}

// GetResponseHasEntry checks whether the supplied GetResponse has the gRIBI
// entry described by the specified want within it. It calls t.Fatalf if no
// such entry is found.
func GetResponseHasEntries(t testing.TB, getres *spb.GetResponse, wants ...fluent.GRIBIEntry) {
// proto.Equal tends to be expensive, so start with building a cache
// so that we do not loop each time. We have to do this by network
// instance, because each NI has its own namespace for each included
// value.

type cache struct {
ipv4 map[string]*spb.AFTEntry
nhg map[uint64]*spb.AFTEntry
nh map[uint64]*spb.AFTEntry
}

netinsts := map[string]*cache{}

for _, r := range getres.GetEntry() {
if _, ok := netinsts[r.NetworkInstance]; !ok {
netinsts[r.NetworkInstance] = &cache{
ipv4: make(map[string]*spb.AFTEntry),
nhg: make(map[uint64]*spb.AFTEntry),
nh: make(map[uint64]*spb.AFTEntry),
}
}
ni := netinsts[r.NetworkInstance]

switch v := r.Entry.(type) {
case *spb.AFTEntry_NextHopGroup:
if id := v.NextHopGroup.GetId(); id != 0 {
ni.nhg[id] = r
}
case *spb.AFTEntry_NextHop:
if idx := v.NextHop.GetIndex(); idx != 0 {
ni.nh[idx] = r
}
case *spb.AFTEntry_Ipv4:
if pfx := v.Ipv4.GetPrefix(); pfx != "" {
ni.ipv4[pfx] = r
}
}
}

for _, want := range wants {
wantProto, err := want.EntryProto()
if err != nil {
t.Fatalf("cannot convert want to an AFTEntry protobuf, %v", err)
}

if wantProto.GetNetworkInstance() == "" {
t.Fatalf("got nil network instance, required.")
}

if wantProto.GetEntry() == nil {
t.Fatalf("got nil entry, required")
}

ni, ok := netinsts[wantProto.GetNetworkInstance()]
if !ok {
t.Fatalf("did not find entry, got: %s, did not find network instance in want: %s", wantProto.NetworkInstance, getres)
}

switch v := wantProto.Entry.(type) {
case *spb.AFTEntry_NextHopGroup:
if _, ok := ni.nhg[v.NextHopGroup.GetId()]; !ok {
t.Fatalf("did not find entry, did not find nexthop group: %s, got:\n%s", v.NextHopGroup, getres)
}
case *spb.AFTEntry_NextHop:
if _, ok := ni.nh[v.NextHop.GetIndex()]; !ok {
t.Fatalf("did not find entry, did not find nexthop: %s, got:\n%s", v.NextHop, getres)
}
case *spb.AFTEntry_Ipv4:
if _, ok := ni.ipv4[v.Ipv4.GetPrefix()]; !ok {
t.Fatalf("did not find entry, did not find ipv4: %s, got: %s\n", v.Ipv4, getres)
}
}
}
}
Loading