Skip to content

Commit 5d2e076

Browse files
lelandbateyhgiasac
andauthored
fix: Union type handling - duplication, pointers (#184)
Co-authored-by: Toan Nguyen <[email protected]>
1 parent cf6ae82 commit 5d2e076

File tree

2 files changed

+198
-28
lines changed

2 files changed

+198
-28
lines changed

pkg/jsonutil/graphql.go

Lines changed: 131 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,39 @@ type decoder struct {
5858
vs []stack
5959
}
6060

61-
type stack []reflect.Value
61+
// stackEntry represents an entry in the decode stack with optional typeName for union types. When
62+
// typeName is non-nil, only the field matching that typename should be unmarshaled into. The 'key'
63+
// field tracks which GraphQL field this entry represents.
64+
type stackEntry struct {
65+
value reflect.Value
66+
// typeName is non-nil only for union fields which should match __typename
67+
typeName *string
68+
// key is the GraphQL field name/key which produced this stackEntry
69+
key string
70+
}
71+
72+
type stack []stackEntry
6273

63-
func (s stack) Top() reflect.Value {
74+
func (s stack) Top() stackEntry {
6475
return s[len(s)-1]
6576
}
6677

6778
func (s stack) Pop() stack {
6879
return s[:len(s)-1]
6980
}
7081

82+
func (s stack) TopValue() reflect.Value {
83+
return s[len(s)-1].value
84+
}
85+
7186
// Decode decodes a single JSON value from d.tokenizer into v.
7287
func (d *decoder) Decode(v interface{}) error {
7388
rv := reflect.ValueOf(v)
7489
if rv.Kind() != reflect.Ptr {
7590
return fmt.Errorf("cannot decode into non-pointer %T", v)
7691
}
7792

78-
d.vs = []stack{{rv.Elem()}}
93+
d.vs = []stack{{stackEntry{value: rv.Elem()}}}
7994

8095
return d.decode()
8196
}
@@ -108,8 +123,13 @@ func (d *decoder) decode() error {
108123
// If one field is raw all must be treated as raw
109124
rawMessage := false
110125
isScalar := false
126+
currentKey := key
127+
128+
// First pass: find which stacks have this field
129+
fieldResults := make([]reflect.Value, len(d.vs))
111130
for i := range d.vs {
112-
v := d.vs[i].Top()
131+
entry := d.vs[i].Top()
132+
v := entry.value
113133
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
114134
v = v.Elem()
115135
}
@@ -134,7 +154,7 @@ func (d *decoder) decode() error {
134154
default:
135155
}
136156

137-
d.vs[i] = append(d.vs[i], f)
157+
fieldResults[i] = f
138158
}
139159

140160
if !someFieldExist {
@@ -145,6 +165,12 @@ func (d *decoder) decode() error {
145165
)
146166
}
147167

168+
// Second pass: append field results to stacks
169+
for i := range d.vs {
170+
f := fieldResults[i]
171+
d.vs[i] = append(d.vs[i], stackEntry{value: f, key: currentKey})
172+
}
173+
148174
if rawMessage || isScalar {
149175
// Read the next complete object from the json stream
150176
var data json.RawMessage
@@ -169,7 +195,8 @@ func (d *decoder) decode() error {
169195
case d.state() == '[' && tok != json.Delim(']'):
170196
someSliceExist := false
171197
for i := range d.vs {
172-
v := d.vs[i].Top()
198+
entry := d.vs[i].Top()
199+
v := entry.value
173200
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
174201
v = v.Elem()
175202
}
@@ -191,7 +218,7 @@ func (d *decoder) decode() error {
191218
}
192219
}
193220

194-
d.vs[i] = append(d.vs[i], f)
221+
d.vs[i] = append(d.vs[i], stackEntry{value: f, key: ""})
195222
}
196223

197224
if !someSliceExist {
@@ -203,8 +230,11 @@ func (d *decoder) decode() error {
203230
case string, json.Number, bool, nil, json.RawMessage:
204231
// Value.
205232

233+
destTypeName := ""
234+
206235
for i := range d.vs {
207-
v := d.vs[i].Top()
236+
entry := d.vs[i].Top()
237+
v := entry.value
208238
if !v.IsValid() {
209239
continue
210240
}
@@ -213,8 +243,23 @@ func (d *decoder) decode() error {
213243
if err != nil {
214244
return err
215245
}
246+
247+
// since we _technically_ are unmarshaling into the top of all valid stacks, leading
248+
// to duplication, we track the __typename so we can later do a delete of the union
249+
// fields which are not named in __typename. This approach was chosen as the path of
250+
// least disruption of this decode() function
251+
if entry.key == "__typename" {
252+
if strVal, ok := tok.(string); ok {
253+
destTypeName = strVal
254+
}
255+
}
216256
}
217257
d.popAllVs()
258+
259+
// After popping, if we just unmarshaled __typename, filter union fields
260+
if destTypeName != "" {
261+
d.filterUnionFieldsByTypeName(destTypeName)
262+
}
218263
case json.Delim:
219264
switch tok {
220265
case '{':
@@ -225,7 +270,8 @@ func (d *decoder) decode() error {
225270
frontier := make([]reflect.Value, len(d.vs)) // Places to look for GraphQL fragments/embedded structs.
226271

227272
for i := range d.vs {
228-
v := d.vs[i].Top()
273+
entry := d.vs[i].Top()
274+
v := entry.value
229275
frontier[i] = v
230276
// TODO: Do this recursively or not? Add a test case if needed.
231277
if v.Kind() == reflect.Ptr && v.IsNil() {
@@ -245,10 +291,20 @@ func (d *decoder) decode() error {
245291

246292
if v.Kind() == reflect.Struct {
247293
for i := 0; i < v.NumField(); i++ {
248-
if isGraphQLFragment(v.Type().Field(i)) || v.Type().Field(i).Anonymous {
294+
field := v.Type().Field(i)
295+
if isGraphQLFragment(field) || field.Anonymous {
249296
// Add GraphQL fragment or embedded struct.
250-
d.vs = append(d.vs, []reflect.Value{v.Field(i)})
251-
frontier = append(frontier, v.Field(i))
297+
fieldVal := v.Field(i)
298+
var typeName *string
299+
if isGraphQLFragment(field) {
300+
typeName = extractUnionFieldTypeName(field.Tag.Get("graphql"))
301+
// Initialize nil pointers in union fields too
302+
if fieldVal.Kind() == reflect.Ptr && fieldVal.IsNil() {
303+
fieldVal.Set(reflect.New(fieldVal.Type().Elem()))
304+
}
305+
}
306+
d.vs = append(d.vs, []stackEntry{{value: fieldVal, typeName: typeName}})
307+
frontier = append(frontier, fieldVal)
252308
}
253309
}
254310
} else if isOrderedMap(v) {
@@ -258,7 +314,9 @@ func (d *decoder) decode() error {
258314

259315
if keyForGraphQLFragment(key.Interface().(string)) {
260316
// Add GraphQL fragment or embedded struct.
261-
d.vs = append(d.vs, []reflect.Value{val})
317+
keyStr := key.Interface().(string)
318+
typeName := extractUnionFieldTypeName(keyStr)
319+
d.vs = append(d.vs, []stackEntry{{value: val, typeName: typeName}})
262320
frontier = append(frontier, val)
263321
}
264322
}
@@ -270,7 +328,8 @@ func (d *decoder) decode() error {
270328
d.pushState(tok)
271329

272330
for i := range d.vs {
273-
v := d.vs[i].Top()
331+
entry := d.vs[i].Top()
332+
v := entry.value
274333
// TODO: Confirm this is needed, write a test case.
275334
// if v.Kind() == reflect.Ptr && v.IsNil() {
276335
// v.Set(reflect.New(v.Type().Elem())) // v = new(T).
@@ -396,14 +455,71 @@ func (d *decoder) popAllVs() {
396455
// popLeftArrayTemplates pops left from last array items of all d.vs stacks.
397456
func (d *decoder) popLeftArrayTemplates() {
398457
for i := range d.vs {
399-
v := d.vs[i].Top()
458+
entry := d.vs[i].Top()
459+
v := entry.value
400460

401461
if v.IsValid() {
402462
v.Set(v.Slice(1, v.Len()))
403463
}
404464
}
405465
}
406466

467+
// extractUnionFieldTypeName extracts the typename from a GraphQL fragment tag like "... on ClosedEvent".
468+
// Returns a pointer to the type name if it's a valid union field, nil otherwise.
469+
func extractUnionFieldTypeName(tag string) *string {
470+
tag = strings.TrimSpace(tag)
471+
if !strings.HasPrefix(tag, "...") {
472+
return nil
473+
}
474+
475+
// Extract the type name after "on"
476+
parts := strings.Fields(tag)
477+
if len(parts) >= 3 && parts[1] == "on" {
478+
typeName := strings.TrimSpace(strings.Join(parts[2:], " "))
479+
// Cut off anything after the type name (like field arguments)
480+
if i := strings.IndexAny(typeName, "(:@"); i != -1 {
481+
typeName = typeName[:i]
482+
}
483+
typeName = strings.TrimSpace(typeName)
484+
if typeName != "" {
485+
return &typeName
486+
}
487+
}
488+
return nil
489+
}
490+
491+
// filterUnionFieldsByTypeName filters d.vs to remove union fields that don't match the given typename.
492+
// This is called after __typename is unmarshaled so we only unmarshal into the correct union variant.
493+
// We keep stacks that don't have a typeName filter, and only keep union fields that match.
494+
// For pointer union fields which don't match, we set them back to nil.
495+
func (d *decoder) filterUnionFieldsByTypeName(typeName string) {
496+
var filtered []stack
497+
498+
for _, st := range d.vs {
499+
if len(st) == 0 {
500+
filtered = append(filtered, st)
501+
continue
502+
}
503+
504+
entry := st[len(st)-1]
505+
if entry.typeName == nil {
506+
// Not a union field (like the parent struct), keep it
507+
filtered = append(filtered, st)
508+
} else if *entry.typeName == typeName {
509+
// Union field which matches the typename, keep it
510+
filtered = append(filtered, st)
511+
} else {
512+
// Union field doesn't match - set it to nil if it's a pointer
513+
v := entry.value
514+
if v.Kind() == reflect.Ptr && v.CanSet() {
515+
v.Set(reflect.Zero(v.Type()))
516+
}
517+
}
518+
}
519+
520+
d.vs = filtered
521+
}
522+
407523
// fieldByGraphQLName returns an exported struct field of struct v
408524
// that matches GraphQL name, or invalid reflect.Value if none found.
409525
func fieldByGraphQLName(v reflect.Value, name string) (reflect.Value, bool) {

pkg/jsonutil/graphql_test.go

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -532,12 +532,66 @@ func TestUnmarshalGraphQL_union(t *testing.T) {
532532
},
533533
CreatedAt: time.Unix(1498709521, 0).UTC(),
534534
},
535-
ReopenedEvent: reopenedEvent{
535+
// ReopenedEvent should be an empty struct because the named type in the message
536+
// (`__typename`) which should be filled in is `ClosedEvent`, so ReopenedEvent should not
537+
// be populated.
538+
ReopenedEvent: reopenedEvent{},
539+
}
540+
if !reflect.DeepEqual(got, want) {
541+
t.Error("not equal")
542+
}
543+
}
544+
545+
func TestUnmarshalGraphQL_unionpointers(t *testing.T) {
546+
/*
547+
{
548+
__typename
549+
... on ClosedEvent {
550+
createdAt
551+
actor {login}
552+
}
553+
... on ReopenedEvent {
554+
createdAt
555+
actor {login}
556+
}
557+
}
558+
*/
559+
type actor struct{ Login string }
560+
type closedEvent struct {
561+
Actor actor
562+
CreatedAt time.Time
563+
}
564+
type reopenedEvent struct {
565+
Actor actor
566+
CreatedAt time.Time
567+
}
568+
type issueTimelineItem struct {
569+
Typename string `graphql:"__typename"`
570+
ClosedEvent *closedEvent `graphql:"... on ClosedEvent"`
571+
ReopenedEvent *reopenedEvent `graphql:"... on ReopenedEvent"`
572+
}
573+
var got issueTimelineItem
574+
err := jsonutil.UnmarshalGraphQL([]byte(`{
575+
"__typename": "ClosedEvent",
576+
"createdAt": "2017-06-29T04:12:01Z",
577+
"actor": {
578+
"login": "shurcooL-test"
579+
}
580+
}`), &got)
581+
if err != nil {
582+
t.Fatal(err)
583+
}
584+
want := issueTimelineItem{
585+
Typename: "ClosedEvent",
586+
ClosedEvent: &closedEvent{
536587
Actor: actor{
537588
Login: "shurcooL-test",
538589
},
539590
CreatedAt: time.Unix(1498709521, 0).UTC(),
540591
},
592+
// ReopenedEvent should be nil because the named type in the message (`__typename`) which
593+
// should be filled in is `ClosedEvent`, so ReopenedEvent should not be populated.
594+
ReopenedEvent: nil,
541595
}
542596
if !reflect.DeepEqual(got, want) {
543597
t.Error("not equal")
@@ -558,10 +612,10 @@ func TestUnmarshalGraphQL_orderedMapUnion(t *testing.T) {
558612
}
559613
}
560614
*/
561-
actor := [][2]interface{}{{"login", ""}}
562-
closedEvent := [][2]interface{}{{"actor", actor}, {"createdAt", time.Time{}}}
563-
reopenedEvent := [][2]interface{}{{"actor", actor}, {"createdAt", time.Time{}}}
564-
got := [][2]interface{}{
615+
actor := [][2]any{{"login", ""}}
616+
closedEvent := [][2]any{{"actor", actor}, {"createdAt", time.Time{}}}
617+
reopenedEvent := [][2]any{{"actor", actor}, {"createdAt", time.Time{}}}
618+
got := [][2]any{
565619
{"__typename", ""},
566620
{"... on ClosedEvent", closedEvent},
567621
{"... on ReopenedEvent", reopenedEvent},
@@ -576,20 +630,20 @@ func TestUnmarshalGraphQL_orderedMapUnion(t *testing.T) {
576630
if err != nil {
577631
t.Fatal(err)
578632
}
579-
want := [][2]interface{}{
633+
want := [][2]any{
580634
{"__typename", "ClosedEvent"},
581-
{"... on ClosedEvent", [][2]interface{}{
582-
{"actor", [][2]interface{}{{"login", "shurcooL-test"}}},
583-
{"createdAt", time.Unix(1498709521, 0).UTC()},
584-
}},
585-
{"... on ReopenedEvent", [][2]interface{}{
586-
{"actor", [][2]interface{}{{"login", "shurcooL-test"}}},
635+
{"... on ClosedEvent", [][2]any{
636+
{"actor", [][2]any{{"login", "shurcooL-test"}}},
587637
{"createdAt", time.Unix(1498709521, 0).UTC()},
588638
}},
639+
// ReopenedEvent is expected to be the "empty" reopenedEvent because the __typename
640+
// indicates that only ClosedEvent should be "filled in". This causes ReopenedEvent to be
641+
// ignored and not changed
642+
{"... on ReopenedEvent", reopenedEvent},
589643
}
590644
if !reflect.DeepEqual(got, want) {
591645
t.Errorf("not equal:\ngot: %v\nwant: %v", got, want)
592-
createdAt := got[1][1].([][2]interface{})[1]
646+
createdAt := got[1][1].([][2]any)[1]
593647
t.Logf("key: %s, type: %v", createdAt[0], reflect.TypeOf(createdAt[1]))
594648
}
595649
}

0 commit comments

Comments
 (0)