diff --git a/dataplane/saiserver/routing.go b/dataplane/saiserver/routing.go index b4bf970d..5ea6db0a 100644 --- a/dataplane/saiserver/routing.go +++ b/dataplane/saiserver/routing.go @@ -1273,9 +1273,17 @@ func (vlan *vlan) RemoveVlan(ctx context.Context, r *saipb.RemoveVlanRequest) (* if vId == DefaultVlanId { return nil, fmt.Errorf("cannot remove default VLAN") } - for _, v := range vlan.vlans[r.GetOid()] { + + vlan.mu.Lock() + var memberOids []uint64 + for mOid := range vlan.vlans[r.GetOid()] { + memberOids = append(memberOids, mOid) + } + vlan.mu.Unlock() + + for _, mOid := range memberOids { _, err := attrmgr.InvokeAndSave(ctx, vlan.mgr, vlan.RemoveVlanMember, &saipb.RemoveVlanMemberRequest{ - Oid: v.Oid, + Oid: mOid, }) if err != nil { return nil, err @@ -1303,14 +1311,30 @@ func (vlan *vlan) CreateVlanMember(ctx context.Context, r *saipb.CreateVlanMembe if err != nil { return nil, err } - member := vlan.memberByPortId(r.GetBridgePortId()) // Keep the vlan member if this port was assigned to any VLAN. + + // Get the port ID from the bridge port. + bpAttrReq := &saipb.GetBridgePortAttributeRequest{ + Oid: r.GetBridgePortId(), + AttrType: []saipb.BridgePortAttr{saipb.BridgePortAttr_BRIDGE_PORT_ATTR_PORT_ID}, + } + bpAttrResp := &saipb.GetBridgePortAttributeResponse{} + if err := vlan.mgr.PopulateAttributes(bpAttrReq, bpAttrResp); err != nil { + return nil, err + } + + portID := bpAttrResp.GetAttr().GetPortId() + + // For the case where this port was assigned to a prior vlan, store the + // vlan member to remove it from cache later. + member := vlan.memberByPortId(portID) + mOid := vlan.mgr.NextID() nid, err := vlan.dataplane.ObjectNID(ctx, &fwdpb.ObjectNIDRequest{ ContextId: &fwdpb.ContextId{Id: vlan.dataplane.ID()}, - ObjectId: &fwdpb.ObjectId{Id: fmt.Sprint(r.GetBridgePortId())}, + ObjectId: &fwdpb.ObjectId{Id: fmt.Sprint(portID)}, }) if err != nil { - slog.InfoContext(ctx, "Failed to find NID for port", "bridge_port", r.GetBridgePortId(), "err", err) + slog.InfoContext(ctx, "Failed to find NID for port", "port_id", portID, "err", err) return nil, err } vlanReq := fwdconfig.TableEntryAddRequest(vlan.dataplane.ID(), VlanTable).AppendEntry( @@ -1331,8 +1355,10 @@ func (vlan *vlan) CreateVlanMember(ctx context.Context, r *saipb.CreateVlanMembe vlanAttrResp.GetAttr().MemberList = append(vlanAttrResp.GetAttr().MemberList, mOid) vlan.mgr.StoreAttributes(vOid, vlanAttrResp.GetAttr()) vlan.mu.Lock() - vlan.vlans[vOid][mOid] = &vlanMember{Oid: mOid, PortID: r.GetBridgePortId(), Vid: vId, Mode: r.GetVlanTaggingMode()} + vlan.vlans[vOid][mOid] = &vlanMember{Oid: mOid, PortID: portID, Vid: vId, Mode: r.GetVlanTaggingMode()} vlan.mu.Unlock() + + // Fetch the original vlan from the old vlan member and remove the member from that vlan if member != nil { preVlanOid := vlan.oidByVId[member.Vid] vlanAttrReq = &saipb.GetVlanAttributeRequest{Oid: preVlanOid, AttrType: []saipb.VlanAttr{saipb.VlanAttr_VLAN_ATTR_MEMBER_LIST}} @@ -1357,8 +1383,23 @@ func (vlan *vlan) CreateVlanMember(ctx context.Context, r *saipb.CreateVlanMembe } func (vlan *vlan) RemoveVlanMember(ctx context.Context, r *saipb.RemoveVlanMemberRequest) (*saipb.RemoveVlanMemberResponse, error) { - member := vlan.memberByOid(r.GetOid()) - if member == nil { + vlan.mu.Lock() + defer vlan.mu.Unlock() + + var member *vlanMember + var targetVlanOid uint64 + found := false + + for vlanOid, members := range vlan.vlans { + if m, ok := members[r.GetOid()]; ok { + member = m + targetVlanOid = vlanOid + found = true + break + } + } + + if !found { return nil, fmt.Errorf("cannot find member with OID %d", r.GetOid()) } nid, err := vlan.dataplane.ObjectNID(ctx, &fwdpb.ObjectNIDRequest{ @@ -1373,6 +1414,9 @@ func (vlan *vlan) RemoveVlanMember(ctx context.Context, r *saipb.RemoveVlanMembe fwdconfig.EntryDesc(fwdconfig.ExactEntry(fwdconfig.PacketFieldBytes(fwdpb.PacketFieldNum_PACKET_FIELD_NUM_PACKET_PORT_INPUT).WithUint64(nid.GetNid())))).Build()); err != nil { return nil, err } + + delete(vlan.vlans[targetVlanOid], r.GetOid()) + return &saipb.RemoveVlanMemberResponse{}, nil } @@ -1444,6 +1488,8 @@ func (b *bridge) CreateBridgePort(ctx context.Context, req *saipb.CreateBridgePo adminState := req.GetAdminState() attrs := &saipb.BridgePortAttribute{ AdminState: proto.Bool(adminState), + PortId: proto.Uint64(req.GetPortId()), + Type: req.Type, } b.mgr.StoreAttributes(oid, attrs) return &saipb.CreateBridgePortResponse{ diff --git a/dataplane/saiserver/routing_test.go b/dataplane/saiserver/routing_test.go index 788b5bed..576dceff 100644 --- a/dataplane/saiserver/routing_test.go +++ b/dataplane/saiserver/routing_test.go @@ -1063,7 +1063,7 @@ func TestRemoveRouterInterface(t *testing.T) { FieldId: &fwdpb.PacketFieldId{Field: &fwdpb.PacketField{ FieldNum: fwdpb.PacketFieldNum_PACKET_FIELD_NUM_PACKET_PORT_INPUT, }}, - Bytes: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + Bytes: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05}, }, { FieldId: &fwdpb.PacketFieldId{Field: &fwdpb.PacketField{ FieldNum: fwdpb.PacketFieldNum_PACKET_FIELD_NUM_VLAN_TAG, @@ -1089,7 +1089,7 @@ func TestRemoveRouterInterface(t *testing.T) { FieldId: &fwdpb.PacketFieldId{Field: &fwdpb.PacketField{ FieldNum: fwdpb.PacketFieldNum_PACKET_FIELD_NUM_PACKET_PORT_INPUT, }}, - Bytes: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + Bytes: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05}, }, { FieldId: &fwdpb.PacketFieldId{Field: &fwdpb.PacketField{ FieldNum: fwdpb.PacketFieldNum_PACKET_FIELD_NUM_VLAN_TAG, diff --git a/dataplane/saiserver/switch_test.go b/dataplane/saiserver/switch_test.go index 5f9483b8..c35fe1b4 100644 --- a/dataplane/saiserver/switch_test.go +++ b/dataplane/saiserver/switch_test.go @@ -16,9 +16,11 @@ package saiserver import ( "context" + "fmt" "io" "log" "net" + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -219,6 +221,7 @@ type fakeSwitchDataplane struct { gotPortCreateReqs []*fwdpb.PortCreateRequest gotPortUpdateReqs []*fwdpb.PortUpdateRequest gotObjectDeleteReqs []*fwdpb.ObjectDeleteRequest + gotObjectNIDReqs []*fwdpb.ObjectNIDRequest gotFlowCounterCreateReqs []*fwdpb.FlowCounterCreateRequest gotFlowCounterQueryReqs []*fwdpb.FlowCounterQueryRequest gotEntryRemoveReqs []*fwdpb.TableEntryRemoveRequest @@ -247,6 +250,12 @@ func (f *fakeSwitchDataplane) TableEntryAdd(_ context.Context, req *fwdpb.TableE } func (f *fakeSwitchDataplane) TableEntryRemove(_ context.Context, req *fwdpb.TableEntryRemoveRequest) (*fwdpb.TableEntryRemoveReply, error) { + for _, prev := range f.gotEntryRemoveReqs { + if cmp.Equal(prev, req, protocmp.Transform()) { + return nil, fmt.Errorf("double table entry removal detected: %v", req) + } + } + f.gotEntryRemoveReqs = append(f.gotEntryRemoveReqs, req) return nil, nil } @@ -292,8 +301,12 @@ func (f *fakeSwitchDataplane) AttributeUpdate(context.Context, *fwdpb.AttributeU return nil, nil } -func (f *fakeSwitchDataplane) ObjectNID(context.Context, *fwdpb.ObjectNIDRequest) (*fwdpb.ObjectNIDReply, error) { - return nil, nil +func (f *fakeSwitchDataplane) ObjectNID(_ context.Context, req *fwdpb.ObjectNIDRequest) (*fwdpb.ObjectNIDReply, error) { + f.gotObjectNIDReqs = append(f.gotObjectNIDReqs, req) + + // Derive unique nid by using request object id to prevent collisions in tests. + id, _ := strconv.ParseUint(req.GetObjectId().GetId(), 10, 64) + return &fwdpb.ObjectNIDReply{Nid: id}, nil } func (f *fakeSwitchDataplane) InjectPacket(_ *fwdpb.ContextId, _ *fwdpb.PortId, _ fwdpb.PacketHeaderId, pkt []byte, _ []*fwdpb.ActionDesc, _ bool, _ fwdpb.PortAction) error { diff --git a/dataplane/saiserver/vlan_test.go b/dataplane/saiserver/vlan_test.go index c754b14a..a084664e 100644 --- a/dataplane/saiserver/vlan_test.go +++ b/dataplane/saiserver/vlan_test.go @@ -128,6 +128,11 @@ func TestVlanOperations(t *testing.T) { mgr.StoreAttributes(1, &saipb.SwitchAttribute{ DefaultStpInstId: proto.Uint64(testStpInstId), }) + // Crete bridge port mapping in attrmgr for each port + mgr.StoreAttributes(11, &saipb.BridgePortAttribute{PortId: proto.Uint64(11)}) + mgr.StoreAttributes(12, &saipb.BridgePortAttribute{PortId: proto.Uint64(12)}) + mgr.StoreAttributes(13, &saipb.BridgePortAttribute{PortId: proto.Uint64(13)}) + mgr.StoreAttributes(14, &saipb.BridgePortAttribute{PortId: proto.Uint64(14)}) ctx := context.TODO() getVLANMembers := func(vlanOID uint64) ([]uint64, error) { @@ -271,3 +276,107 @@ func newTestVlan(t testing.TB, api switchDataplaneAPI) (saipb.VlanClient, *attrm }) return saipb.NewVlanClient(conn), mgr, stopFn } + +func TestCreateVlanMemberLogicalMapping(t *testing.T) { + dplane := &fakeSwitchDataplane{} + c, mgr, stopFn := newTestVlan(t, dplane) + defer stopFn() + ctx := context.TODO() + + mgr.StoreAttributes(1, &saipb.SwitchAttribute{ + DefaultStpInstId: proto.Uint64(10), + }) + + // Create vlan. + resp, err := c.CreateVlan(ctx, &saipb.CreateVlanRequest{ + Switch: 1, + VlanId: proto.Uint32(10), + }) + if err != nil { + t.Fatalf("CreateVlan failed: %v", err) + } + vlanOid := resp.Oid + + // Create mock bridge port mapping in attrmgr for given port id. + bridgePortOid := uint64(100) + testPortId := uint64(10) + mgr.StoreAttributes(bridgePortOid, &saipb.BridgePortAttribute{ + PortId: proto.Uint64(testPortId), + }) + + // Create vlan member using bridge port oid. + _, err = c.CreateVlanMember(ctx, &saipb.CreateVlanMemberRequest{ + Switch: 1, + VlanId: &vlanOid, + BridgePortId: &bridgePortOid, + }) + if err != nil { + t.Fatalf("CreateVlanMember failed: %v", err) + } + + // Verify that ObjectNID was queried with the port id. + found := false + for _, req := range dplane.gotObjectNIDReqs { + if req.GetObjectId().GetId() == "10" { + found = true + break + } + } + if !found { + t.Errorf("ObjectNID was not queried with port id 10, got reqs: %+v", dplane.gotObjectNIDReqs) + } +} + +func TestRemoveVlan(t *testing.T) { + dplane := &fakeSwitchDataplane{} + c, mgr, stopFn := newTestVlan(t, dplane) + defer stopFn() + ctx := context.TODO() + + mgr.StoreAttributes(1, &saipb.SwitchAttribute{ + DefaultStpInstId: proto.Uint64(testStpInstId), + }) + + // Create vlan. + vlanResp, err := c.CreateVlan(ctx, &saipb.CreateVlanRequest{ + Switch: 1, + VlanId: proto.Uint32(10), + }) + if err != nil { + t.Fatalf("CreateVlan failed: %v", err) + } + vlanOid := vlanResp.Oid + + // Create mapping for mock bridge ports. + bridgePortOid := uint64(100) + physicalPortId := uint64(10) + mgr.StoreAttributes(bridgePortOid, &saipb.BridgePortAttribute{ + PortId: proto.Uint64(physicalPortId), + }) + + // Create vlan member with bridge port. + memberResp, err := c.CreateVlanMember(ctx, &saipb.CreateVlanMemberRequest{ + Switch: 1, + VlanId: &vlanOid, + BridgePortId: &bridgePortOid, + }) + if err != nil { + t.Fatalf("CreateVlanMember failed: %v", err) + } + + // Remove vlan member. + _, err = c.RemoveVlanMember(ctx, &saipb.RemoveVlanMemberRequest{ + Oid: memberResp.Oid, + }) + if err != nil { + t.Fatalf("RemoveVlanMember failed: %v", err) + } + + // Remove vlan. + _, err = c.RemoveVlan(ctx, &saipb.RemoveVlanRequest{ + Oid: vlanOid, + }) + if err != nil { + t.Fatalf("RemoveVlan failed: %v", err) + } +}