Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions btf/btf.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (s *immutableTypes) typeByID(id TypeID) (Type, bool) {
// mutableTypes is a set of types which may be changed.
type mutableTypes struct {
imm immutableTypes
mu *sync.RWMutex // protects copies below
mu sync.RWMutex // protects copies below
copies map[Type]Type // map[orig]copy
copiedTypeIDs map[Type]TypeID // map[copy]origID
}
Expand Down Expand Up @@ -94,10 +94,14 @@ func (mt *mutableTypes) add(typ Type, typeIDs map[Type]TypeID) Type {
}

// copy a set of mutable types.
func (mt *mutableTypes) copy() mutableTypes {
mtCopy := mutableTypes{
func (mt *mutableTypes) copy() *mutableTypes {
if mt == nil {
return nil
}

mtCopy := &mutableTypes{
mt.imm,
&sync.RWMutex{},
sync.RWMutex{},
make(map[Type]Type, len(mt.copies)),
make(map[Type]TypeID, len(mt.copiedTypeIDs)),
}
Expand Down Expand Up @@ -169,7 +173,7 @@ func (mt *mutableTypes) anyTypesByName(name string) ([]Type, error) {
// Spec allows querying a set of Types and loading the set into the
// kernel.
type Spec struct {
mutableTypes
*mutableTypes

// String table from ELF.
strings *stringTable
Expand Down Expand Up @@ -339,15 +343,15 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, base *Spec) (*Spec, error
typeIDs, typesByName := indexTypes(types, firstTypeID)

return &Spec{
mutableTypes{
&mutableTypes{
immutableTypes{
types,
typeIDs,
firstTypeID,
typesByName,
bo,
},
&sync.RWMutex{},
sync.RWMutex{},
make(map[Type]Type),
make(map[Type]TypeID),
},
Expand Down Expand Up @@ -522,6 +526,10 @@ func fixupDatasecLayout(ds *Datasec) error {

// Copy creates a copy of Spec.
func (s *Spec) Copy() *Spec {
if s == nil {
return nil
}

return &Spec{
s.mutableTypes.copy(),
s.strings,
Expand Down
2 changes: 2 additions & 0 deletions btf/btf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ func TestGuessBTFByteOrder(t *testing.T) {
}

func TestSpecCopy(t *testing.T) {
qt.Check(t, qt.IsNil((*Spec)(nil).Copy()))

spec := parseELFBTF(t, "../testdata/loader-el.elf")
cpy := spec.Copy()

Expand Down
2 changes: 1 addition & 1 deletion collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (cs *CollectionSpec) Copy() *CollectionSpec {
Maps: make(map[string]*MapSpec, len(cs.Maps)),
Programs: make(map[string]*ProgramSpec, len(cs.Programs)),
ByteOrder: cs.ByteOrder,
Types: cs.Types,
Types: cs.Types.Copy(),
}

for name, spec := range cs.Maps {
Expand Down
26 changes: 7 additions & 19 deletions collection_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ebpf

import (
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -57,15 +58,15 @@ func TestCollectionSpecNotModified(t *testing.T) {

func TestCollectionSpecCopy(t *testing.T) {
cs := &CollectionSpec{
Maps: map[string]*MapSpec{
map[string]*MapSpec{
"my-map": {
Type: Array,
KeySize: 4,
ValueSize: 4,
MaxEntries: 1,
},
},
Programs: map[string]*ProgramSpec{
map[string]*ProgramSpec{
"test": {
Type: SocketFilter,
Instructions: asm.Instructions{
Expand All @@ -76,25 +77,12 @@ func TestCollectionSpecCopy(t *testing.T) {
License: "MIT",
},
},
Types: &btf.Spec{},
&btf.Spec{},
binary.LittleEndian,
}
cpy := cs.Copy()

if cpy == cs {
t.Error("Copy returned the same pointer")
}

if cpy.Maps["my-map"] == cs.Maps["my-map"] {
t.Error("Copy returned same Maps")
}

if cpy.Programs["test"] == cs.Programs["test"] {
t.Error("Copy returned same Programs")
}

if cpy.Types != cs.Types {
t.Error("Copy returned different Types")
}
qt.Check(t, qt.IsNil((*CollectionSpec)(nil).Copy()))
qt.Assert(t, testutils.IsDeepCopy(cs.Copy(), cs))
}

func TestCollectionSpecLoadCopy(t *testing.T) {
Expand Down
158 changes: 158 additions & 0 deletions internal/testutils/checkers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package testutils

import (
"bytes"
"fmt"
"reflect"

"github.com/go-quicktest/qt"
)

// IsDeepCopy checks that got is a deep copy of want.
//
// All primitive values must be equal, but pointers must be distinct.
// This is different from [reflect.DeepEqual] which will accept equal pointer values.
// That is, reflect.DeepEqual(a, a) is true, while IsDeepCopy(a, a) is false.
func IsDeepCopy[T any](got, want T) qt.Checker {
return &deepCopyChecker[T]{got, want, make(map[pair]struct{})}
}

type pair struct {
got, want reflect.Value
}

type deepCopyChecker[T any] struct {
got, want T
visited map[pair]struct{}
}

func (dcc *deepCopyChecker[T]) Check(_ func(key string, value any)) error {
return dcc.check(reflect.ValueOf(dcc.got), reflect.ValueOf(dcc.want))
}

func (dcc *deepCopyChecker[T]) check(got, want reflect.Value) error {
switch want.Kind() {
case reflect.Interface:
return dcc.check(got.Elem(), want.Elem())

case reflect.Pointer:
if got.IsNil() && want.IsNil() {
return nil
}

if got.IsNil() {
return fmt.Errorf("expected non-nil pointer")
}

if want.IsNil() {
return fmt.Errorf("expected nil pointer")
}

if got.UnsafePointer() == want.UnsafePointer() {
return fmt.Errorf("equal pointer values")
}

switch want.Type() {
case reflect.TypeOf((*bytes.Reader)(nil)):
// bytes.Reader doesn't allow modifying it's contents, so we
// allow a shallow copy.
return nil
}

if _, ok := dcc.visited[pair{got, want}]; ok {
// Deal with recursive types.
return nil
}

dcc.visited[pair{got, want}] = struct{}{}
return dcc.check(got.Elem(), want.Elem())

case reflect.Slice:
if got.IsNil() && want.IsNil() {
return nil
}

if got.IsNil() {
return fmt.Errorf("expected non-nil slice")
}

if want.IsNil() {
return fmt.Errorf("expected nil slice")
}

if got.Len() != want.Len() {
return fmt.Errorf("expected %d elements, got %d", want.Len(), got.Len())
}

if want.Len() == 0 {
return nil
}

if got.UnsafePointer() == want.UnsafePointer() {
return fmt.Errorf("equal backing memory")
}

fallthrough

case reflect.Array:
for i := 0; i < want.Len(); i++ {
if err := dcc.check(got.Index(i), want.Index(i)); err != nil {
return fmt.Errorf("index %d: %w", i, err)
}
}

return nil

case reflect.Struct:
for i := 0; i < want.NumField(); i++ {
if err := dcc.check(got.Field(i), want.Field(i)); err != nil {
return fmt.Errorf("%q: %w", want.Type().Field(i).Name, err)
}
}

return nil

case reflect.Map:
if got.Len() != want.Len() {
return fmt.Errorf("expected %d items, got %d", want.Len(), got.Len())
}

if got.UnsafePointer() == want.UnsafePointer() {
return fmt.Errorf("maps are equal")
}

iter := want.MapRange()
for iter.Next() {
key := iter.Key()
got := got.MapIndex(iter.Key())
if !got.IsValid() {
return fmt.Errorf("key %v is missing", key)
}

want := iter.Value()
if err := dcc.check(got, want); err != nil {
return fmt.Errorf("key %v: %w", key, err)
}
}

return nil

case reflect.Chan, reflect.UnsafePointer:
return fmt.Errorf("%s is not supported", want.Type())

default:
// Compare by value as usual.
if !got.Equal(want) {
return fmt.Errorf("%#v is not equal to %#v", got, want)
}

return nil
}
}

func (dcc *deepCopyChecker[T]) Args() []qt.Arg {
return []qt.Arg{
{Name: "got", Value: dcc.got},
{Name: "want", Value: dcc.want},
}
}
Loading