Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions resolver/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,116 @@ func (a *AddressMap) Values() []any {
}
return ret
}

type endpointNode struct {
addrs map[string]struct{}
}

// Equal returns whether the unordered set of addrs are the same between the
// endpoint nodes.
func (en *endpointNode) Equal(en2 *endpointNode) bool {
if len(en.addrs) != len(en2.addrs) {
return false
}
for addr := range en.addrs {
if _, ok := en2.addrs[addr]; !ok {
return false
}
}
return true
}

func toEndpointNode(endpoint Endpoint) endpointNode {
en := make(map[string]struct{})
for _, addr := range endpoint.Addresses {
en[addr.Addr] = struct{}{}
}
return endpointNode{
addrs: en,
}
}

// EndpointMap is a map of endpoints to arbitrary values keyed on only the
// unordered set of address strings within an endpoint. This map is not thread
// safe, thus it is unsafe to access concurrently. Must be created via
// NewEndpointMap; do not construct directly.
type EndpointMap struct {
endpoints map[*endpointNode]any
}

// NewEndpointMap creates a new EndpointMap.
func NewEndpointMap() *EndpointMap {
return &EndpointMap{
endpoints: make(map[*endpointNode]any),
}
}

// Get returns the value for the address in the map, if present.
func (em *EndpointMap) Get(e Endpoint) (value any, ok bool) {
en := toEndpointNode(e)
if endpoint := em.find(en); endpoint != nil {
return em.endpoints[endpoint], true
}
return nil, false
}

// Set updates or adds the value to the address in the map.
func (em *EndpointMap) Set(e Endpoint, value any) {
en := toEndpointNode(e)
if endpoint := em.find(en); endpoint != nil {
em.endpoints[endpoint] = value
return
}
em.endpoints[&en] = value
}

// Len returns the number of entries in the map.
func (em *EndpointMap) Len() int {
return len(em.endpoints)
}

// Keys returns a slice of all current map keys, as endpoints specifying the
// addresses present in the endpoint keys, in which uniqueness is determined by
// the unordered set of addresses. Thus, endpoint information returned is not
// the full endpoint data (drops duplicated addresses and attributes) but can be
// used for EndpointMap accesses.
func (em *EndpointMap) Keys() []Endpoint {
ret := make([]Endpoint, 0, len(em.endpoints))
for en := range em.endpoints {
var endpoint Endpoint
for addr := range en.addrs {
endpoint.Addresses = append(endpoint.Addresses, Address{Addr: addr})
}
ret = append(ret, endpoint)
}
return ret
}

// Values returns a slice of all current map values.
func (em *EndpointMap) Values() []any {
ret := make([]any, 0, len(em.endpoints))
for _, val := range em.endpoints {
ret = append(ret, val)
}
return ret
}

// find returns a pointer to the endpoint node in em if the endpoint node is
// already present. If not found, nil is returned. The comparisons are done on
// the unordered set of addresses within an endpoint.
func (em EndpointMap) find(e endpointNode) *endpointNode {
for endpoint := range em.endpoints {
if e.Equal(endpoint) {
return endpoint
}
}
return nil
}

// Delete removes the specified endpoint from the map.
func (em *EndpointMap) Delete(e Endpoint) {
en := toEndpointNode(e)
if entry := em.find(en); entry != nil {
delete(em.endpoints, entry)
}
}
126 changes: 126 additions & 0 deletions resolver/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ var (
addr5 = Address{Addr: "a1", Attributes: attributes.New("a1", "3"), ServerName: "s1"}
addr6 = Address{Addr: "a1", Attributes: attributes.New("a1", 3), ServerName: "s2"}
addr7 = Address{Addr: "a1", Attributes: attributes.New("a1", 3), ServerName: "s1", BalancerAttributes: attributes.New("xx", 3)}

endpoint1 = Endpoint{Addresses: []Address{{Addr: "addr1"}}}
endpoint2 = Endpoint{Addresses: []Address{{Addr: "addr2"}}}
endpoint3 = Endpoint{Addresses: []Address{{Addr: "addr3"}}}
endpoint4 = Endpoint{Addresses: []Address{{Addr: "addr4"}}}
endpoint5 = Endpoint{Addresses: []Address{{Addr: "addr5"}}}
endpoint6 = Endpoint{Addresses: []Address{{Addr: "addr6"}}}
endpoint7 = Endpoint{Addresses: []Address{{Addr: "addr7"}}}
endpoint12 = Endpoint{Addresses: []Address{{Addr: "addr1"}, {Addr: "addr2"}}}
endpoint21 = Endpoint{Addresses: []Address{{Addr: "addr2"}, {Addr: "addr1"}}}
endpoint123 = Endpoint{Addresses: []Address{{Addr: "addr1"}, {Addr: "addr2"}, {Addr: "addr3"}}}
)

func (s) TestAddressMap_Length(t *testing.T) {
Expand Down Expand Up @@ -161,3 +172,118 @@ func (s) TestAddressMap_Values(t *testing.T) {
t.Fatalf("addrMap.Values returned unexpected elements (-want, +got):\n%v", diff)
}
}

func (s) TestEndpointMap_Length(t *testing.T) {
em := NewEndpointMap()
// Should be empty at creation time.
if got := em.Len(); got != 0 {
t.Fatalf("em.Len() = %v; want 0", got)
}
// Add two endpoints with the same unordered set of addresses. This should
// amount to one endpoint. It should also not take into account attributes.
em.Set(endpoint12, struct{}{})
em.Set(endpoint21, struct{}{})

if got := em.Len(); got != 1 {
t.Fatalf("em.Len() = %v; want 1", got)
}

// Add another unique endpoint. This should cause the length to be 2.
em.Set(endpoint123, struct{}{})
if got := em.Len(); got != 2 {
t.Fatalf("em.Len() = %v; want 2", got)
}
}

func (s) TestEndpointMap_Get(t *testing.T) {
em := NewEndpointMap()
em.Set(endpoint1, 1)
// The second endpoint endpoint21 should override.
em.Set(endpoint12, 1)
em.Set(endpoint21, 2)
em.Set(endpoint3, 3)
em.Set(endpoint4, 4)
em.Set(endpoint5, 5)
em.Set(endpoint6, 6)
em.Set(endpoint7, 7)

if got, ok := em.Get(endpoint1); !ok || got.(int) != 1 {
t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 1)
}
if got, ok := em.Get(endpoint12); !ok || got.(int) != 2 {
t.Fatalf("em.Get(endpoint12) = %v, %v; want %v, true", got, ok, 2)
}
if got, ok := em.Get(endpoint21); !ok || got.(int) != 2 {
t.Fatalf("em.Get(endpoint21) = %v, %v; want %v, true", got, ok, 2)
}
if got, ok := em.Get(endpoint3); !ok || got.(int) != 3 {
t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 3)
}
if got, ok := em.Get(endpoint4); !ok || got.(int) != 4 {
t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 4)
}
if got, ok := em.Get(endpoint5); !ok || got.(int) != 5 {
t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 5)
}
if got, ok := em.Get(endpoint6); !ok || got.(int) != 6 {
t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 6)
}
if got, ok := em.Get(endpoint7); !ok || got.(int) != 7 {
t.Fatalf("em.Get(endpoint1) = %v, %v; want %v, true", got, ok, 7)
}
if _, ok := em.Get(endpoint123); ok {
t.Fatalf("em.Get(endpoint123) = _, %v; want _, false", ok)
}
}

func (s) TestEndpointMap_Delete(t *testing.T) {
em := NewEndpointMap()
// Initial state of system: [1, 2, 3, 12]
em.Set(endpoint1, struct{}{})
em.Set(endpoint2, struct{}{})
em.Set(endpoint3, struct{}{})
em.Set(endpoint12, struct{}{})
// Delete: [2, 21]
em.Delete(endpoint2)
em.Delete(endpoint21)

// [1, 3] should be present:
if _, ok := em.Get(endpoint1); !ok {
t.Fatalf("em.Get(endpoint1) = %v, want true", ok)
}
if _, ok := em.Get(endpoint3); !ok {
t.Fatalf("em.Get(endpoint3) = %v, want true", ok)
}
// [2, 12] should not be present:
if _, ok := em.Get(endpoint2); ok {
t.Fatalf("em.Get(endpoint2) = %v, want false", ok)
}
if _, ok := em.Get(endpoint12); ok {
t.Fatalf("em.Get(endpoint12) = %v, want false", ok)
}
if _, ok := em.Get(endpoint21); ok {
t.Fatalf("em.Get(endpoint21) = %v, want false", ok)
}
}

func (s) TestEndpointMap_Values(t *testing.T) {
em := NewEndpointMap()
em.Set(endpoint1, 1)
// The second endpoint endpoint21 should override.
em.Set(endpoint12, 1)
em.Set(endpoint21, 2)
em.Set(endpoint3, 3)
em.Set(endpoint4, 4)
em.Set(endpoint5, 5)
em.Set(endpoint6, 6)
em.Set(endpoint7, 7)
want := []int{1, 2, 3, 4, 5, 6, 7}
var got []int
for _, v := range em.Values() {
got = append(got, v.(int))
}
sort.Ints(got)
if diff := cmp.Diff(want, got); diff != "" {
t.Fatalf("em.Values() returned unexpected elements (-want, +got):\n%v", diff)
}
}