Skip to content

Commit 6d00909

Browse files
authored
Pass context to completion (#1265)
1 parent 7223a99 commit 6d00909

File tree

4 files changed

+94
-1
lines changed

4 files changed

+94
-1
lines changed

command.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,8 @@ func (c *Command) preRun() {
887887
}
888888

889889
// ExecuteContext is the same as Execute(), but sets the ctx on the command.
890-
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
890+
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
891+
// functions.
891892
func (c *Command) ExecuteContext(ctx context.Context) error {
892893
c.ctx = ctx
893894
return c.Execute()
@@ -901,6 +902,14 @@ func (c *Command) Execute() error {
901902
return err
902903
}
903904

905+
// ExecuteContextC is the same as ExecuteC(), but sets the ctx on the command.
906+
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
907+
// functions.
908+
func (c *Command) ExecuteContextC(ctx context.Context) (*Command, error) {
909+
c.ctx = ctx
910+
return c.ExecuteC()
911+
}
912+
904913
// ExecuteC executes the command.
905914
func (c *Command) ExecuteC() (cmd *Command, err error) {
906915
if c.ctx == nil {

command_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ func executeCommandC(root *Command, args ...string) (c *Command, output string,
4242
return c, buf.String(), err
4343
}
4444

45+
func executeCommandWithContextC(ctx context.Context, root *Command, args ...string) (c *Command, output string, err error) {
46+
buf := new(bytes.Buffer)
47+
root.SetOut(buf)
48+
root.SetErr(buf)
49+
root.SetArgs(args)
50+
51+
c, err = root.ExecuteContextC(ctx)
52+
53+
return c, buf.String(), err
54+
}
55+
4556
func resetCommandLineFlagSet() {
4657
pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError)
4758
}
@@ -178,6 +189,35 @@ func TestExecuteContext(t *testing.T) {
178189
}
179190
}
180191

192+
func TestExecuteContextC(t *testing.T) {
193+
ctx := context.TODO()
194+
195+
ctxRun := func(cmd *Command, args []string) {
196+
if cmd.Context() != ctx {
197+
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
198+
}
199+
}
200+
201+
rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
202+
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
203+
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}
204+
205+
childCmd.AddCommand(granchildCmd)
206+
rootCmd.AddCommand(childCmd)
207+
208+
if _, _, err := executeCommandWithContextC(ctx, rootCmd, ""); err != nil {
209+
t.Errorf("Root command must not fail: %+v", err)
210+
}
211+
212+
if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child"); err != nil {
213+
t.Errorf("Subcommand must not fail: %+v", err)
214+
}
215+
216+
if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child", "grandchild"); err != nil {
217+
t.Errorf("Command child must not fail: %+v", err)
218+
}
219+
}
220+
181221
func TestExecute_NoContext(t *testing.T) {
182222
run := func(cmd *Command, args []string) {
183223
if cmd.Context() != context.Background() {

completions.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
221221
// Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
222222
return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
223223
}
224+
finalCmd.ctx = c.ctx
224225

225226
// Check if we are doing flag value completion before parsing the flags.
226227
// This is important because if we are completing a flag value, we need to also

completions_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cobra
22

33
import (
44
"bytes"
5+
"context"
56
"strings"
67
"testing"
78
)
@@ -1203,6 +1204,48 @@ func TestFlagDirFilterCompletionInGo(t *testing.T) {
12031204
}
12041205
}
12051206

1207+
func TestValidArgsFuncCmdContext(t *testing.T) {
1208+
validArgsFunc := func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
1209+
ctx := cmd.Context()
1210+
1211+
if ctx == nil {
1212+
t.Error("Received nil context in completion func")
1213+
} else if ctx.Value("testKey") != "123" {
1214+
t.Error("Received invalid context")
1215+
}
1216+
1217+
return nil, ShellCompDirectiveDefault
1218+
}
1219+
1220+
rootCmd := &Command{
1221+
Use: "root",
1222+
Run: emptyRun,
1223+
}
1224+
childCmd := &Command{
1225+
Use: "childCmd",
1226+
ValidArgsFunction: validArgsFunc,
1227+
Run: emptyRun,
1228+
}
1229+
rootCmd.AddCommand(childCmd)
1230+
1231+
//nolint:golint,staticcheck // We can safely use a basic type as key in tests.
1232+
ctx := context.WithValue(context.Background(), "testKey", "123")
1233+
1234+
// Test completing an empty string on the childCmd
1235+
_, output, err := executeCommandWithContextC(ctx, rootCmd, ShellCompNoDescRequestCmd, "childCmd", "")
1236+
if err != nil {
1237+
t.Errorf("Unexpected error: %v", err)
1238+
}
1239+
1240+
expected := strings.Join([]string{
1241+
":0",
1242+
"Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n")
1243+
1244+
if output != expected {
1245+
t.Errorf("expected: %q, got: %q", expected, output)
1246+
}
1247+
}
1248+
12061249
func TestValidArgsFuncSingleCmd(t *testing.T) {
12071250
rootCmd := &Command{
12081251
Use: "root",

0 commit comments

Comments
 (0)