Skip to content

Commit be7482b

Browse files
committed
fix a variety of Copy() problems
The are multiple Copy() methods in the code base which are used to create deep copies of a struct. Any bugs in them can manifest as data races in concurrent code. Add a quicktest checker which ensures that two variables are a deep copy of each other: values match, but all locations in memory differ. Fix a variety of problems with the existing Copy() implementations. They all allow copying a nil struct and copy all elements. There are exceptions to the deep copy rule: - MapSpec.Contents is not deep copied because we can't easily make copies of interface values. This is documented already. - MapSpec.Extra is an immutable bytes.Reader and therefore only needs a shallow copy. - ProgramSpec.AttachTarget is a Program, which is (currently) safe for concurrent use. Fixes #1517 Signed-off-by: Lorenz Bauer <[email protected]>
1 parent fbb9ed8 commit be7482b

9 files changed

Lines changed: 353 additions & 30 deletions

File tree

btf/btf.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (s *immutableTypes) typeByID(id TypeID) (Type, bool) {
6666
// mutableTypes is a set of types which may be changed.
6767
type mutableTypes struct {
6868
imm immutableTypes
69-
mu *sync.RWMutex // protects copies below
69+
mu sync.RWMutex // protects copies below
7070
copies map[Type]Type // map[orig]copy
7171
copiedTypeIDs map[Type]TypeID // map[copy]origID
7272
}
@@ -94,10 +94,14 @@ func (mt *mutableTypes) add(typ Type, typeIDs map[Type]TypeID) Type {
9494
}
9595

9696
// copy a set of mutable types.
97-
func (mt *mutableTypes) copy() mutableTypes {
98-
mtCopy := mutableTypes{
97+
func (mt *mutableTypes) copy() *mutableTypes {
98+
if mt == nil {
99+
return nil
100+
}
101+
102+
mtCopy := &mutableTypes{
99103
mt.imm,
100-
&sync.RWMutex{},
104+
sync.RWMutex{},
101105
make(map[Type]Type, len(mt.copies)),
102106
make(map[Type]TypeID, len(mt.copiedTypeIDs)),
103107
}
@@ -169,7 +173,7 @@ func (mt *mutableTypes) anyTypesByName(name string) ([]Type, error) {
169173
// Spec allows querying a set of Types and loading the set into the
170174
// kernel.
171175
type Spec struct {
172-
mutableTypes
176+
*mutableTypes
173177

174178
// String table from ELF.
175179
strings *stringTable
@@ -339,15 +343,15 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, base *Spec) (*Spec, error
339343
typeIDs, typesByName := indexTypes(types, firstTypeID)
340344

341345
return &Spec{
342-
mutableTypes{
346+
&mutableTypes{
343347
immutableTypes{
344348
types,
345349
typeIDs,
346350
firstTypeID,
347351
typesByName,
348352
bo,
349353
},
350-
&sync.RWMutex{},
354+
sync.RWMutex{},
351355
make(map[Type]Type),
352356
make(map[Type]TypeID),
353357
},
@@ -522,6 +526,10 @@ func fixupDatasecLayout(ds *Datasec) error {
522526

523527
// Copy creates a copy of Spec.
524528
func (s *Spec) Copy() *Spec {
529+
if s == nil {
530+
return nil
531+
}
532+
525533
return &Spec{
526534
s.mutableTypes.copy(),
527535
s.strings,

btf/btf_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ func TestGuessBTFByteOrder(t *testing.T) {
326326
}
327327

328328
func TestSpecCopy(t *testing.T) {
329+
qt.Check(t, qt.IsNil((*Spec)(nil).Copy()))
330+
329331
spec := parseELFBTF(t, "../testdata/loader-el.elf")
330332
cpy := spec.Copy()
331333

collection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (cs *CollectionSpec) Copy() *CollectionSpec {
5757
Maps: make(map[string]*MapSpec, len(cs.Maps)),
5858
Programs: make(map[string]*ProgramSpec, len(cs.Programs)),
5959
ByteOrder: cs.ByteOrder,
60-
Types: cs.Types,
60+
Types: cs.Types.Copy(),
6161
}
6262

6363
for name, spec := range cs.Maps {

collection_test.go

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ebpf
22

33
import (
4+
"encoding/binary"
45
"errors"
56
"fmt"
67
"io"
@@ -57,15 +58,15 @@ func TestCollectionSpecNotModified(t *testing.T) {
5758

5859
func TestCollectionSpecCopy(t *testing.T) {
5960
cs := &CollectionSpec{
60-
Maps: map[string]*MapSpec{
61+
map[string]*MapSpec{
6162
"my-map": {
6263
Type: Array,
6364
KeySize: 4,
6465
ValueSize: 4,
6566
MaxEntries: 1,
6667
},
6768
},
68-
Programs: map[string]*ProgramSpec{
69+
map[string]*ProgramSpec{
6970
"test": {
7071
Type: SocketFilter,
7172
Instructions: asm.Instructions{
@@ -76,25 +77,12 @@ func TestCollectionSpecCopy(t *testing.T) {
7677
License: "MIT",
7778
},
7879
},
79-
Types: &btf.Spec{},
80+
&btf.Spec{},
81+
binary.LittleEndian,
8082
}
81-
cpy := cs.Copy()
8283

83-
if cpy == cs {
84-
t.Error("Copy returned the same pointer")
85-
}
86-
87-
if cpy.Maps["my-map"] == cs.Maps["my-map"] {
88-
t.Error("Copy returned same Maps")
89-
}
90-
91-
if cpy.Programs["test"] == cs.Programs["test"] {
92-
t.Error("Copy returned same Programs")
93-
}
94-
95-
if cpy.Types != cs.Types {
96-
t.Error("Copy returned different Types")
97-
}
84+
qt.Check(t, qt.IsNil((*CollectionSpec)(nil).Copy()))
85+
qt.Assert(t, testutils.IsDeepCopy(cs.Copy(), cs))
9886
}
9987

10088
func TestCollectionSpecLoadCopy(t *testing.T) {

internal/testutils/checkers.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package testutils
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"reflect"
7+
8+
"github.com/go-quicktest/qt"
9+
)
10+
11+
// IsDeepCopy checks that got is a deep copy of want.
12+
//
13+
// All primitive values must be equal, but pointers must be distinct.
14+
// This is different from [reflect.DeepEqual] which will accept equal pointer values.
15+
// That is, reflect.DeepEqual(a, a) is true, while IsDeepCopy(a, a) is false.
16+
func IsDeepCopy[T any](got, want T) qt.Checker {
17+
return &deepCopyChecker[T]{got, want, make(map[pair]struct{})}
18+
}
19+
20+
type pair struct {
21+
got, want reflect.Value
22+
}
23+
24+
type deepCopyChecker[T any] struct {
25+
got, want T
26+
visited map[pair]struct{}
27+
}
28+
29+
func (dcc *deepCopyChecker[T]) Check(_ func(key string, value any)) error {
30+
return dcc.check(reflect.ValueOf(dcc.got), reflect.ValueOf(dcc.want))
31+
}
32+
33+
func (dcc *deepCopyChecker[T]) check(got, want reflect.Value) error {
34+
switch want.Kind() {
35+
case reflect.Interface:
36+
return dcc.check(got.Elem(), want.Elem())
37+
38+
case reflect.Pointer:
39+
if got.IsNil() && want.IsNil() {
40+
return nil
41+
}
42+
43+
if got.IsNil() {
44+
return fmt.Errorf("expected non-nil pointer")
45+
}
46+
47+
if want.IsNil() {
48+
return fmt.Errorf("expected nil pointer")
49+
}
50+
51+
if got.UnsafePointer() == want.UnsafePointer() {
52+
return fmt.Errorf("equal pointer values")
53+
}
54+
55+
switch want.Type() {
56+
case reflect.TypeOf((*bytes.Reader)(nil)):
57+
// bytes.Reader doesn't allow modifying it's contents, so we
58+
// allow a shallow copy.
59+
return nil
60+
}
61+
62+
if _, ok := dcc.visited[pair{got, want}]; ok {
63+
// Deal with recursive types.
64+
return nil
65+
}
66+
67+
dcc.visited[pair{got, want}] = struct{}{}
68+
return dcc.check(got.Elem(), want.Elem())
69+
70+
case reflect.Slice:
71+
if got.IsNil() && want.IsNil() {
72+
return nil
73+
}
74+
75+
if got.IsNil() {
76+
return fmt.Errorf("expected non-nil slice")
77+
}
78+
79+
if want.IsNil() {
80+
return fmt.Errorf("expected nil slice")
81+
}
82+
83+
if got.Len() != want.Len() {
84+
return fmt.Errorf("expected %d elements, got %d", want.Len(), got.Len())
85+
}
86+
87+
if want.Len() == 0 {
88+
return nil
89+
}
90+
91+
if got.UnsafePointer() == want.UnsafePointer() {
92+
return fmt.Errorf("equal backing memory")
93+
}
94+
95+
fallthrough
96+
97+
case reflect.Array:
98+
for i := 0; i < want.Len(); i++ {
99+
if err := dcc.check(got.Index(i), want.Index(i)); err != nil {
100+
return fmt.Errorf("index %d: %w", i, err)
101+
}
102+
}
103+
104+
return nil
105+
106+
case reflect.Struct:
107+
for i := 0; i < want.NumField(); i++ {
108+
if err := dcc.check(got.Field(i), want.Field(i)); err != nil {
109+
return fmt.Errorf("%q: %w", want.Type().Field(i).Name, err)
110+
}
111+
}
112+
113+
return nil
114+
115+
case reflect.Map:
116+
if got.Len() != want.Len() {
117+
return fmt.Errorf("expected %d items, got %d", want.Len(), got.Len())
118+
}
119+
120+
if got.UnsafePointer() == want.UnsafePointer() {
121+
return fmt.Errorf("maps are equal")
122+
}
123+
124+
iter := want.MapRange()
125+
for iter.Next() {
126+
key := iter.Key()
127+
got := got.MapIndex(iter.Key())
128+
if !got.IsValid() {
129+
return fmt.Errorf("key %v is missing", key)
130+
}
131+
132+
want := iter.Value()
133+
if err := dcc.check(got, want); err != nil {
134+
return fmt.Errorf("key %v: %w", key, err)
135+
}
136+
}
137+
138+
return nil
139+
140+
case reflect.Chan, reflect.UnsafePointer:
141+
return fmt.Errorf("%s is not supported", want.Type())
142+
143+
default:
144+
// Compare by value as usual.
145+
if !got.Equal(want) {
146+
return fmt.Errorf("%#v is not equal to %#v", got, want)
147+
}
148+
149+
return nil
150+
}
151+
}
152+
153+
func (dcc *deepCopyChecker[T]) Args() []qt.Arg {
154+
return []qt.Arg{
155+
{Name: "got", Value: dcc.got},
156+
{Name: "want", Value: dcc.want},
157+
}
158+
}

0 commit comments

Comments
 (0)