diff --git a/cmd/bpf2go/gen/output.go b/cmd/bpf2go/gen/output.go index fbda8df62..56ab312c7 100644 --- a/cmd/bpf2go/gen/output.go +++ b/cmd/bpf2go/gen/output.go @@ -169,7 +169,7 @@ func Generate(args GenerateArgs) error { if err != nil { return fmt.Errorf("generating %s: %w", name, err) } - _, ok := typ.(*btf.Struct) + _, ok := btf.As[*btf.Struct](typ) needsStructsPkg = needsStructsPkg || ok typeDecls = append(typeDecls, decl) } diff --git a/cmd/bpf2go/gen/output_test.go b/cmd/bpf2go/gen/output_test.go index 3859087fd..539620c26 100644 --- a/cmd/bpf2go/gen/output_test.go +++ b/cmd/bpf2go/gen/output_test.go @@ -118,3 +118,56 @@ func TestObjects(t *testing.T) { qt.Assert(t, qt.StringContains(str, "Var1 *ebpf.Variable `ebpf:\"var_1\"`")) qt.Assert(t, qt.StringContains(str, "ProgFoo1 *ebpf.Program `ebpf:\"prog_foo_1\"`")) } + +func TestGenerateStructTypes(t *testing.T) { + ts := &btf.Struct{ + Name: "test_struct", + Size: 8, + Members: []btf.Member{ + { + Name: "field1", + Type: &btf.Int{Size: 8, Encoding: btf.Unsigned}, + Offset: 0, + }, + }, + } + td := &btf.Typedef{ + Name: "test_typedef", + Type: ts, + } + + tests := []struct { + name string + types []btf.Type + expected string + }{ + { + name: "simple struct", + types: []btf.Type{ts}, + expected: "type stemTestStruct struct {\n\t_ structs.HostLayout\n\tField1 uint64\n}", + }, + { + name: "typedef struct", + types: []btf.Type{td}, + expected: "type stemTestTypedef struct {\n\t_ structs.HostLayout\n\tField1 uint64\n}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + err := Generate(GenerateArgs{ + Package: "test", + Stem: "stem", + Types: tt.types, + Output: &buf, + Constraints: nil, + }) + qt.Assert(t, qt.IsNil(err)) + + str := buf.String() + qt.Assert(t, qt.StringContains(str, tt.expected)) + qt.Assert(t, qt.StringContains(str, "\"structs\"")) + }) + } +}