@@ -6,7 +6,10 @@ import (
66 "strings"
77)
88
9- type bindings map [reflect.Type ]func () (reflect.Value , error )
9+ // A map of type to function that returns a value of that type.
10+ //
11+ // The function should have the signature func(...) (T, error). Arguments are recursively resolved.
12+ type bindings map [reflect.Type ]any
1013
1114func (b bindings ) String () string {
1215 out := []string {}
@@ -19,32 +22,23 @@ func (b bindings) String() string {
1922func (b bindings ) add (values ... interface {}) bindings {
2023 for _ , v := range values {
2124 v := v
22- b [reflect .TypeOf (v )] = func () (reflect. Value , error ) { return reflect . ValueOf ( v ) , nil }
25+ b [reflect .TypeOf (v )] = func () (any , error ) { return v , nil }
2326 }
2427 return b
2528}
2629
2730func (b bindings ) addTo (impl , iface interface {}) {
28- valueOf := reflect .ValueOf (impl )
29- b [reflect .TypeOf (iface ).Elem ()] = func () (reflect.Value , error ) { return valueOf , nil }
31+ b [reflect .TypeOf (iface ).Elem ()] = func () (any , error ) { return impl , nil }
3032}
3133
3234func (b bindings ) addProvider (provider interface {}) error {
3335 pv := reflect .ValueOf (provider )
3436 t := pv .Type ()
35- if t .Kind () != reflect .Func || t .NumIn () != 0 || t . NumOut () != 2 || t .Out (1 ) != reflect .TypeOf ((* error )(nil )).Elem () {
36- return fmt .Errorf ("%T must be a function with the signature func()(T, error)" , provider )
37+ if t .Kind () != reflect .Func || t .NumOut () != 2 || t .Out (1 ) != reflect .TypeOf ((* error )(nil )).Elem () {
38+ return fmt .Errorf ("%T must be a function with the signature func(... )(T, error)" , provider )
3739 }
3840 rt := pv .Type ().Out (0 )
39- b [rt ] = func () (reflect.Value , error ) {
40- out := pv .Call (nil )
41- errv := out [1 ]
42- var err error
43- if ! errv .IsNil () {
44- err = errv .Interface ().(error ) //nolint
45- }
46- return out [0 ], err
47- }
41+ b [rt ] = provider
4842 return nil
4943}
5044
@@ -101,15 +95,19 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
10195 t := f .Type ()
10296 for i := 0 ; i < t .NumIn (); i ++ {
10397 pt := t .In (i )
104- if argf , ok := bindings [pt ]; ok {
105- argv , err := argf ()
106- if err != nil {
107- return nil , err
108- }
109- in = append (in , argv )
110- } else {
98+ argf , ok := bindings [pt ]
99+ if ! ok {
111100 return nil , fmt .Errorf ("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)" , pt , i , t , pt )
112101 }
102+ // Recursively resolve binding functions.
103+ argv , err := callAnyFunction (reflect .ValueOf (argf ), bindings )
104+ if err != nil {
105+ return nil , fmt .Errorf ("%s: %w" , pt , err )
106+ }
107+ if ferrv := reflect .ValueOf (argv [len (argv )- 1 ]); ferrv .IsValid () && ! ferrv .IsNil () {
108+ return nil , ferrv .Interface ().(error ) //nolint:forcetypeassert
109+ }
110+ in = append (in , reflect .ValueOf (argv [0 ]))
113111 }
114112 outv := f .Call (in )
115113 out = make ([]any , len (outv ))
0 commit comments