Skip to content

Commit 7bbb0b7

Browse files
committed
feat: support recursive injection of provider parameters
This allows provider functions to accept parameters that are injected by other bindings or binding providers, eg. call the provider function with the root CLI struct (which is automatically bound by Kong): kong.BindToProvider(func(cli *CLI) (*Injected, error) { ... })
1 parent 373692a commit 7bbb0b7

File tree

3 files changed

+31
-24
lines changed

3 files changed

+31
-24
lines changed

callbacks.go

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import (
66
"strings"
77
)
88

9-
type bindings map[reflect.Type]func() (reflect.Value, error)
9+
// A map of type to function that returns a value of that type.
10+
//
11+
// The function should have the signature func(...) (T, error). Arguments are recursively resolved.
12+
type bindings map[reflect.Type]any
1013

1114
func (b bindings) String() string {
1215
out := []string{}
@@ -19,32 +22,23 @@ func (b bindings) String() string {
1922
func (b bindings) add(values ...interface{}) bindings {
2023
for _, v := range values {
2124
v := v
22-
b[reflect.TypeOf(v)] = func() (reflect.Value, error) { return reflect.ValueOf(v), nil }
25+
b[reflect.TypeOf(v)] = func() (any, error) { return v, nil }
2326
}
2427
return b
2528
}
2629

2730
func (b bindings) addTo(impl, iface interface{}) {
28-
valueOf := reflect.ValueOf(impl)
29-
b[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil }
31+
b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil }
3032
}
3133

3234
func (b bindings) addProvider(provider interface{}) error {
3335
pv := reflect.ValueOf(provider)
3436
t := pv.Type()
35-
if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
36-
return fmt.Errorf("%T must be a function with the signature func()(T, error)", provider)
37+
if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
38+
return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider)
3739
}
3840
rt := pv.Type().Out(0)
39-
b[rt] = func() (reflect.Value, error) {
40-
out := pv.Call(nil)
41-
errv := out[1]
42-
var err error
43-
if !errv.IsNil() {
44-
err = errv.Interface().(error) //nolint
45-
}
46-
return out[0], err
47-
}
41+
b[rt] = provider
4842
return nil
4943
}
5044

@@ -101,15 +95,19 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
10195
t := f.Type()
10296
for i := 0; i < t.NumIn(); i++ {
10397
pt := t.In(i)
104-
if argf, ok := bindings[pt]; ok {
105-
argv, err := argf()
106-
if err != nil {
107-
return nil, err
108-
}
109-
in = append(in, argv)
110-
} else {
98+
argf, ok := bindings[pt]
99+
if !ok {
111100
return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
112101
}
102+
// Recursively resolve binding functions.
103+
argv, err := callAnyFunction(reflect.ValueOf(argf), bindings)
104+
if err != nil {
105+
return nil, fmt.Errorf("%s: %w", pt, err)
106+
}
107+
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() {
108+
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
109+
}
110+
in = append(in, reflect.ValueOf(argv[0]))
113111
}
114112
outv := f.Call(in)
115113
out = make([]any, len(outv))

options.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,11 @@ func BindTo(impl, iface interface{}) Option {
208208
})
209209
}
210210

211-
// BindToProvider allows binding of provider functions.
211+
// BindToProvider binds an injected value to a provider function.
212+
//
213+
// The provider function must have the signature:
214+
//
215+
// func() (interface{}, error)
212216
//
213217
// This is useful when the Run() function of different commands require different values that may
214218
// not all be initialisable from the main() function.

options_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,13 @@ func TestCallbackCustomError(t *testing.T) {
8989
}
9090

9191
type bindToProviderCLI struct {
92+
Filled bool `default:"true"`
9293
Called bool
9394
Cmd bindToProviderCmd `cmd:""`
9495
}
9596

9697
type boundThing struct {
98+
Filled bool
9799
}
98100

99101
type bindToProviderCmd struct{}
@@ -105,7 +107,10 @@ func (*bindToProviderCmd) Run(cli *bindToProviderCLI, b *boundThing) error {
105107

106108
func TestBindToProvider(t *testing.T) {
107109
var cli bindToProviderCLI
108-
app, err := New(&cli, BindToProvider(func() (*boundThing, error) { return &boundThing{}, nil }))
110+
app, err := New(&cli, BindToProvider(func(cli *bindToProviderCLI) (*boundThing, error) {
111+
assert.True(t, cli.Filled, "CLI struct should have already been populated by Kong")
112+
return &boundThing{Filled: cli.Filled}, nil
113+
}))
109114
assert.NoError(t, err)
110115
ctx, err := app.Parse([]string{"cmd"})
111116
assert.NoError(t, err)

0 commit comments

Comments
 (0)