Skip to content

Commit 956e5de

Browse files
committed
spf13#199 with some modifications
1 parent e43c76f commit 956e5de

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

flag.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ const (
122122

123123
// ParseErrorsWhitelist defines the parsing errors that can be ignored
124124
type ParseErrorsWhitelist struct {
125-
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
125+
// UnknownFlags will ignore unknown flags errors and continue parsing the rest of the flags.
126+
// Consider using SetUnknownFlags/GetUnknownFlags if you need to know which unknown flags occured.
126127
UnknownFlags bool
127128
}
128129

@@ -162,6 +163,7 @@ type FlagSet struct {
162163
output io.Writer // nil means stderr; use Output() accessor
163164
interspersed bool // allow interspersed option/non-option args
164165
normalizeNameFunc func(f *FlagSet, name string) NormalizedName
166+
unknownFlags *[]string
165167

166168
addedGoFlagSets []*goflag.FlagSet
167169
}
@@ -964,10 +966,17 @@ func (f *FlagSet) usage() {
964966
}
965967
}
966968

969+
func (f *FlagSet) addUnknownFlag(s string) {
970+
if f.unknownFlags == nil {
971+
f.unknownFlags = new([]string)
972+
}
973+
*f.unknownFlags = append(*f.unknownFlags, s)
974+
}
975+
967976
//--unknown (args will be empty)
968977
//--unknown --next-flag ... (args will be --next-flag ...)
969978
//--unknown arg ... (args will be arg ...)
970-
func stripUnknownFlagValue(args []string) []string {
979+
func (f *FlagSet) stripUnknownFlagValue(args []string) []string {
971980
if len(args) == 0 {
972981
//--unknown
973982
return args
@@ -981,6 +990,7 @@ func stripUnknownFlagValue(args []string) []string {
981990

982991
//--unknown arg ... (args will be arg ...)
983992
if len(args) > 1 {
993+
f.addUnknownFlag(args[0])
984994
return args[1:]
985995
}
986996
return nil
@@ -1007,13 +1017,14 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
10071017
}
10081018
return
10091019
case f.ParseErrorsWhitelist.UnknownFlags:
1020+
f.addUnknownFlag(s)
10101021
// --unknown=unknownval arg ...
10111022
// we do not want to lose arg in this case
10121023
if len(split) >= 2 {
10131024
return a, nil
10141025
}
10151026

1016-
return stripUnknownFlagValue(a), nil
1027+
return f.stripUnknownFlagValue(a), nil
10171028
default:
10181029
err = f.failf("unknown flag: --%s", name)
10191030
return
@@ -1063,11 +1074,15 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
10631074
// '-f=arg arg ...'
10641075
// we do not want to lose arg in this case
10651076
if len(shorthands) > 2 && shorthands[1] == '=' {
1077+
f.addUnknownFlag("-" + shorthands)
10661078
outShorts = ""
10671079
return
10681080
}
10691081

1070-
outArgs = stripUnknownFlagValue(outArgs)
1082+
f.addUnknownFlag("-" + string(c))
1083+
if len(outShorts) == 0 {
1084+
outArgs = f.stripUnknownFlagValue(outArgs)
1085+
}
10711086
return
10721087
default:
10731088
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
@@ -1223,6 +1238,21 @@ func (f *FlagSet) Parsed() bool {
12231238
return f.parsed
12241239
}
12251240

1241+
// SetUnknownFlags sets the store for unknown flags found during Parse.
1242+
// The argument s points to a slice variable in which to store the values.
1243+
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
1244+
// parsing does not abort on the first unknown flag.
1245+
func (f *FlagSet) SetUnknownFlags(s *[]string) {
1246+
f.unknownFlags = s
1247+
}
1248+
1249+
// GetUnknownFlags returns unknown flags found during Parse.
1250+
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
1251+
// parsing does not abort on the first unknown flag.
1252+
func (f *FlagSet) GetUnknownFlags() *[]string {
1253+
return f.unknownFlags
1254+
}
1255+
12261256
// Parse parses the command-line flags from os.Args[1:]. Must be called
12271257
// after all flags are defined and before flags are accessed by the program.
12281258
func Parse() {
@@ -1248,6 +1278,21 @@ func Parsed() bool {
12481278
return CommandLine.Parsed()
12491279
}
12501280

1281+
// SetUnknownFlags sets the store for unknown flags found during Parse.
1282+
// The argument s points to a slice variable in which to store the values.
1283+
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
1284+
// parsing does not abort on the first unknown flag.
1285+
func SetUnknownFlags(s *[]string) {
1286+
CommandLine.SetUnknownFlags(s)
1287+
}
1288+
1289+
// GetUnknownFlags returns unknown flags found during Parse.
1290+
// This requires ParseErrorsWhitelist.UnknownFlags to be set so that
1291+
// parsing does not abort on the first unknown flag.
1292+
func GetUnknownFlags() *[]string {
1293+
return CommandLine.GetUnknownFlags()
1294+
}
1295+
12511296
// CommandLine is the default set of command-line flags, parsed from os.Args.
12521297
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)
12531298

flag_test.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
405405
t.Error("f.Parse() = true before Parse")
406406
}
407407
f.ParseErrorsWhitelist.UnknownFlags = true
408+
var unknownFlags []string
409+
f.SetUnknownFlags(&unknownFlags)
408410

409411
f.BoolP("boola", "a", false, "bool value")
410412
f.BoolP("boolb", "b", false, "bool2 value")
@@ -455,6 +457,19 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
455457
"stringo", "ovalue",
456458
"boole", "true",
457459
}
460+
wantUnknowns := []string{
461+
"--unknown1", "unknown1Value",
462+
"--unknown2=unknown2Value",
463+
"-u=unknown3Value",
464+
"-p", "unknown4Value",
465+
"-q",
466+
"--unknown7=unknown7value",
467+
"--unknown8=unknown8value",
468+
"--unknown6", "",
469+
"-u", "-u", "-u", "-u", "-u", "",
470+
"--unknown10",
471+
"--unknown11",
472+
}
458473
got := []string{}
459474
store := func(flag *Flag, value string) error {
460475
got = append(got, flag.Name)
@@ -470,10 +485,15 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) {
470485
t.Errorf("f.Parse() = false after Parse")
471486
}
472487
if !reflect.DeepEqual(got, want) {
473-
t.Errorf("f.ParseAll() fail to restore the args")
488+
t.Errorf("f.Parse() failed to parse with unknown flags")
474489
t.Errorf("Got: %v", got)
475490
t.Errorf("Want: %v", want)
476491
}
492+
if !reflect.DeepEqual(unknownFlags, wantUnknowns) {
493+
t.Errorf("f.Parse() failed to enumerate the unknown flags")
494+
t.Errorf("Got: %v", unknownFlags)
495+
t.Errorf("Want: %v", wantUnknowns)
496+
}
477497
}
478498

479499
func TestShorthand(t *testing.T) {

0 commit comments

Comments
 (0)