@@ -58,24 +58,39 @@ type decoder struct {
5858 vs []stack
5959}
6060
61- type stack []reflect.Value
61+ // stackEntry represents an entry in the decode stack with optional typeName for union types. When
62+ // typeName is non-nil, only the field matching that typename should be unmarshaled into. The 'key'
63+ // field tracks which GraphQL field this entry represents.
64+ type stackEntry struct {
65+ value reflect.Value
66+ // typeName is non-nil only for union fields which should match __typename
67+ typeName * string
68+ // key is the GraphQL field name/key which produced this stackEntry
69+ key string
70+ }
71+
72+ type stack []stackEntry
6273
63- func (s stack ) Top () reflect. Value {
74+ func (s stack ) Top () stackEntry {
6475 return s [len (s )- 1 ]
6576}
6677
6778func (s stack ) Pop () stack {
6879 return s [:len (s )- 1 ]
6980}
7081
82+ func (s stack ) TopValue () reflect.Value {
83+ return s [len (s )- 1 ].value
84+ }
85+
7186// Decode decodes a single JSON value from d.tokenizer into v.
7287func (d * decoder ) Decode (v interface {}) error {
7388 rv := reflect .ValueOf (v )
7489 if rv .Kind () != reflect .Ptr {
7590 return fmt .Errorf ("cannot decode into non-pointer %T" , v )
7691 }
7792
78- d .vs = []stack {{rv .Elem ()}}
93+ d .vs = []stack {{stackEntry { value : rv .Elem ()} }}
7994
8095 return d .decode ()
8196}
@@ -108,8 +123,13 @@ func (d *decoder) decode() error {
108123 // If one field is raw all must be treated as raw
109124 rawMessage := false
110125 isScalar := false
126+ currentKey := key
127+
128+ // First pass: find which stacks have this field
129+ fieldResults := make ([]reflect.Value , len (d .vs ))
111130 for i := range d .vs {
112- v := d .vs [i ].Top ()
131+ entry := d .vs [i ].Top ()
132+ v := entry .value
113133 for v .Kind () == reflect .Ptr || v .Kind () == reflect .Interface {
114134 v = v .Elem ()
115135 }
@@ -134,7 +154,7 @@ func (d *decoder) decode() error {
134154 default :
135155 }
136156
137- d . vs [i ] = append ( d . vs [ i ], f )
157+ fieldResults [i ] = f
138158 }
139159
140160 if ! someFieldExist {
@@ -145,6 +165,12 @@ func (d *decoder) decode() error {
145165 )
146166 }
147167
168+ // Second pass: append field results to stacks
169+ for i := range d .vs {
170+ f := fieldResults [i ]
171+ d .vs [i ] = append (d .vs [i ], stackEntry {value : f , key : currentKey })
172+ }
173+
148174 if rawMessage || isScalar {
149175 // Read the next complete object from the json stream
150176 var data json.RawMessage
@@ -169,7 +195,8 @@ func (d *decoder) decode() error {
169195 case d .state () == '[' && tok != json .Delim (']' ):
170196 someSliceExist := false
171197 for i := range d .vs {
172- v := d .vs [i ].Top ()
198+ entry := d .vs [i ].Top ()
199+ v := entry .value
173200 for v .Kind () == reflect .Ptr || v .Kind () == reflect .Interface {
174201 v = v .Elem ()
175202 }
@@ -191,7 +218,7 @@ func (d *decoder) decode() error {
191218 }
192219 }
193220
194- d .vs [i ] = append (d .vs [i ], f )
221+ d .vs [i ] = append (d .vs [i ], stackEntry { value : f , key : "" } )
195222 }
196223
197224 if ! someSliceExist {
@@ -203,8 +230,11 @@ func (d *decoder) decode() error {
203230 case string , json.Number , bool , nil , json.RawMessage :
204231 // Value.
205232
233+ destTypeName := ""
234+
206235 for i := range d .vs {
207- v := d .vs [i ].Top ()
236+ entry := d .vs [i ].Top ()
237+ v := entry .value
208238 if ! v .IsValid () {
209239 continue
210240 }
@@ -213,8 +243,23 @@ func (d *decoder) decode() error {
213243 if err != nil {
214244 return err
215245 }
246+
247+ // since we _technically_ are unmarshaling into the top of all valid stacks, leading
248+ // to duplication, we track the __typename so we can later do a delete of the union
249+ // fields which are not named in __typename. This approach was chosen as the path of
250+ // least disruption of this decode() function
251+ if entry .key == "__typename" {
252+ if strVal , ok := tok .(string ); ok {
253+ destTypeName = strVal
254+ }
255+ }
216256 }
217257 d .popAllVs ()
258+
259+ // After popping, if we just unmarshaled __typename, filter union fields
260+ if destTypeName != "" {
261+ d .filterUnionFieldsByTypeName (destTypeName )
262+ }
218263 case json.Delim :
219264 switch tok {
220265 case '{' :
@@ -225,7 +270,8 @@ func (d *decoder) decode() error {
225270 frontier := make ([]reflect.Value , len (d .vs )) // Places to look for GraphQL fragments/embedded structs.
226271
227272 for i := range d .vs {
228- v := d .vs [i ].Top ()
273+ entry := d .vs [i ].Top ()
274+ v := entry .value
229275 frontier [i ] = v
230276 // TODO: Do this recursively or not? Add a test case if needed.
231277 if v .Kind () == reflect .Ptr && v .IsNil () {
@@ -245,10 +291,20 @@ func (d *decoder) decode() error {
245291
246292 if v .Kind () == reflect .Struct {
247293 for i := 0 ; i < v .NumField (); i ++ {
248- if isGraphQLFragment (v .Type ().Field (i )) || v .Type ().Field (i ).Anonymous {
294+ field := v .Type ().Field (i )
295+ if isGraphQLFragment (field ) || field .Anonymous {
249296 // Add GraphQL fragment or embedded struct.
250- d .vs = append (d .vs , []reflect.Value {v .Field (i )})
251- frontier = append (frontier , v .Field (i ))
297+ fieldVal := v .Field (i )
298+ var typeName * string
299+ if isGraphQLFragment (field ) {
300+ typeName = extractUnionFieldTypeName (field .Tag .Get ("graphql" ))
301+ // Initialize nil pointers in union fields too
302+ if fieldVal .Kind () == reflect .Ptr && fieldVal .IsNil () {
303+ fieldVal .Set (reflect .New (fieldVal .Type ().Elem ()))
304+ }
305+ }
306+ d .vs = append (d .vs , []stackEntry {{value : fieldVal , typeName : typeName }})
307+ frontier = append (frontier , fieldVal )
252308 }
253309 }
254310 } else if isOrderedMap (v ) {
@@ -258,7 +314,9 @@ func (d *decoder) decode() error {
258314
259315 if keyForGraphQLFragment (key .Interface ().(string )) {
260316 // Add GraphQL fragment or embedded struct.
261- d .vs = append (d .vs , []reflect.Value {val })
317+ keyStr := key .Interface ().(string )
318+ typeName := extractUnionFieldTypeName (keyStr )
319+ d .vs = append (d .vs , []stackEntry {{value : val , typeName : typeName }})
262320 frontier = append (frontier , val )
263321 }
264322 }
@@ -270,7 +328,8 @@ func (d *decoder) decode() error {
270328 d .pushState (tok )
271329
272330 for i := range d .vs {
273- v := d .vs [i ].Top ()
331+ entry := d .vs [i ].Top ()
332+ v := entry .value
274333 // TODO: Confirm this is needed, write a test case.
275334 // if v.Kind() == reflect.Ptr && v.IsNil() {
276335 // v.Set(reflect.New(v.Type().Elem())) // v = new(T).
@@ -396,14 +455,71 @@ func (d *decoder) popAllVs() {
396455// popLeftArrayTemplates pops left from last array items of all d.vs stacks.
397456func (d * decoder ) popLeftArrayTemplates () {
398457 for i := range d .vs {
399- v := d .vs [i ].Top ()
458+ entry := d .vs [i ].Top ()
459+ v := entry .value
400460
401461 if v .IsValid () {
402462 v .Set (v .Slice (1 , v .Len ()))
403463 }
404464 }
405465}
406466
467+ // extractUnionFieldTypeName extracts the typename from a GraphQL fragment tag like "... on ClosedEvent".
468+ // Returns a pointer to the type name if it's a valid union field, nil otherwise.
469+ func extractUnionFieldTypeName (tag string ) * string {
470+ tag = strings .TrimSpace (tag )
471+ if ! strings .HasPrefix (tag , "..." ) {
472+ return nil
473+ }
474+
475+ // Extract the type name after "on"
476+ parts := strings .Fields (tag )
477+ if len (parts ) >= 3 && parts [1 ] == "on" {
478+ typeName := strings .TrimSpace (strings .Join (parts [2 :], " " ))
479+ // Cut off anything after the type name (like field arguments)
480+ if i := strings .IndexAny (typeName , "(:@" ); i != - 1 {
481+ typeName = typeName [:i ]
482+ }
483+ typeName = strings .TrimSpace (typeName )
484+ if typeName != "" {
485+ return & typeName
486+ }
487+ }
488+ return nil
489+ }
490+
491+ // filterUnionFieldsByTypeName filters d.vs to remove union fields that don't match the given typename.
492+ // This is called after __typename is unmarshaled so we only unmarshal into the correct union variant.
493+ // We keep stacks that don't have a typeName filter, and only keep union fields that match.
494+ // For pointer union fields which don't match, we set them back to nil.
495+ func (d * decoder ) filterUnionFieldsByTypeName (typeName string ) {
496+ var filtered []stack
497+
498+ for _ , st := range d .vs {
499+ if len (st ) == 0 {
500+ filtered = append (filtered , st )
501+ continue
502+ }
503+
504+ entry := st [len (st )- 1 ]
505+ if entry .typeName == nil {
506+ // Not a union field (like the parent struct), keep it
507+ filtered = append (filtered , st )
508+ } else if * entry .typeName == typeName {
509+ // Union field which matches the typename, keep it
510+ filtered = append (filtered , st )
511+ } else {
512+ // Union field doesn't match - set it to nil if it's a pointer
513+ v := entry .value
514+ if v .Kind () == reflect .Ptr && v .CanSet () {
515+ v .Set (reflect .Zero (v .Type ()))
516+ }
517+ }
518+ }
519+
520+ d .vs = filtered
521+ }
522+
407523// fieldByGraphQLName returns an exported struct field of struct v
408524// that matches GraphQL name, or invalid reflect.Value if none found.
409525func fieldByGraphQLName (v reflect.Value , name string ) (reflect.Value , bool ) {
0 commit comments