Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package pflag

import "fmt"

// notExistErrorMessageType specifies which flavor of "flag does not exist"
// is printed by NotExistError. This allows the related errors to be grouped
// under a single NotExistError struct without making a breaking change to
// the error message text.
type notExistErrorMessageType int

const (
flagNotExistMessage notExistErrorMessageType = iota
flagNotDefinedMessage
flagNoSuchFlagMessage
flagUnknownFlagMessage
flagUnknownShorthandFlagMessage
)

// NotExistError is the error returned when trying to access a flag that
// does not exist in the FlagSet.
type NotExistError struct {
name string
specifiedShorthands string
messageType notExistErrorMessageType
}

// Error implements error.
func (e *NotExistError) Error() string {
switch e.messageType {
case flagNotExistMessage:
return fmt.Sprintf("flag %q does not exist", e.name)

case flagNotDefinedMessage:
return fmt.Sprintf("flag accessed but not defined: %s", e.name)

case flagNoSuchFlagMessage:
return fmt.Sprintf("no such flag -%v", e.name)

case flagUnknownFlagMessage:
return fmt.Sprintf("unknown flag: --%s", e.name)

case flagUnknownShorthandFlagMessage:
c := rune(e.name[0])
return fmt.Sprintf("unknown shorthand flag: %q in -%s", c, e.specifiedShorthands)
}

panic(fmt.Errorf("unknown flagNotExistErrorMessageType: %v", e.messageType))
}

// GetSpecifiedName returns the name of the flag (without dashes) as it
// appeared in the parsed arguments.
func (e *NotExistError) GetSpecifiedName() string {
return e.name
}

// GetSpecifiedShortnames returns the group of shorthand arguments
// (without dashes) that the flag appeared within. If the flag was not in a
// shorthand group, this will return an empty string.
func (e *NotExistError) GetSpecifiedShortnames() string {
return e.specifiedShorthands
}

// ValueRequiredError is the error returned when a flag needs an argument but
// no argument was provided.
type ValueRequiredError struct {
flag *Flag
specifiedName string
specifiedShorthands string
}

// Error implements error.
func (e *ValueRequiredError) Error() string {
if len(e.specifiedShorthands) > 0 {
c := rune(e.specifiedName[0])
return fmt.Sprintf("flag needs an argument: %q in -%s", c, e.specifiedShorthands)
}

return fmt.Sprintf("flag needs an argument: --%s", e.specifiedName)
}

// GetFlag returns the flag for which the error occurred.
func (e *ValueRequiredError) GetFlag() *Flag {
return e.flag
}

// GetSpecifiedName returns the name of the flag (without dashes) as it
// appeared in the parsed arguments.
func (e *ValueRequiredError) GetSpecifiedName() string {
return e.specifiedName
}

// GetSpecifiedShortnames returns the group of shorthand arguments
// (without dashes) that the flag appeared within. If the flag was not in a
// shorthand group, this will return an empty string.
func (e *ValueRequiredError) GetSpecifiedShortnames() string {
return e.specifiedShorthands
}

// InvalidValueError is the error returned when an invalid value is used
// for a flag.
type InvalidValueError struct {
flag *Flag
value string
cause error
}

// Error implements error.
func (e *InvalidValueError) Error() string {
flag := e.flag
var flagName string
if flag.Shorthand != "" && flag.ShorthandDeprecated == "" {
flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name)
} else {
flagName = fmt.Sprintf("--%s", flag.Name)
}
return fmt.Sprintf("invalid argument %q for %q flag: %v", e.value, flagName, e.cause)
}

// Unwrap implements errors.Unwrap.
func (e *InvalidValueError) Unwrap() error {
return e.cause
}

// GetFlag returns the flag for which the error occurred.
func (e *InvalidValueError) GetFlag() *Flag {
return e.flag
}

// GetValue returns the invalid value that was provided.
func (e *InvalidValueError) GetValue() string {
return e.value
}

// InvalidSyntaxError is the error returned when a bad flag name is passed on
// the command line.
type InvalidSyntaxError struct {
specifiedFlag string
}

// Error implements error.
func (e *InvalidSyntaxError) Error() string {
return fmt.Sprintf("bad flag syntax: %s", e.specifiedFlag)
}

// GetSpecifiedName returns the exact flag (with dashes) as it
// appeared in the parsed arguments.
func (e *InvalidSyntaxError) GetSpecifiedFlag() string {
return e.specifiedFlag
}
67 changes: 67 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package pflag

import (
"errors"
"testing"
)

func TestNotExistError(t *testing.T) {
err := &NotExistError{
name: "foo",
specifiedShorthands: "bar",
}

if err.GetSpecifiedName() != "foo" {
t.Errorf("Expected GetSpecifiedName to return %q, got %q", "foo", err.GetSpecifiedName())
}
if err.GetSpecifiedShortnames() != "bar" {
t.Errorf("Expected GetSpecifiedShortnames to return %q, got %q", "bar", err.GetSpecifiedShortnames())
}
}

func TestValueRequiredError(t *testing.T) {
err := &ValueRequiredError{
flag: &Flag{},
specifiedName: "foo",
specifiedShorthands: "bar",
}

if err.GetFlag() == nil {
t.Error("Expected GetSpecifiedName to return its flag field, but got nil")
}
if err.GetSpecifiedName() != "foo" {
t.Errorf("Expected GetSpecifiedName to return %q, got %q", "foo", err.GetSpecifiedName())
}
if err.GetSpecifiedShortnames() != "bar" {
t.Errorf("Expected GetSpecifiedShortnames to return %q, got %q", "bar", err.GetSpecifiedShortnames())
}
}

func TestInvalidValueError(t *testing.T) {
expectedCause := errors.New("error")
err := &InvalidValueError{
flag: &Flag{},
value: "foo",
cause: expectedCause,
}

if err.GetFlag() == nil {
t.Error("Expected GetSpecifiedName to return its flag field, but got nil")
}
if err.GetValue() != "foo" {
t.Errorf("Expected GetValue to return %q, got %q", "foo", err.GetValue())
}
if err.Unwrap() != expectedCause {
t.Errorf("Expected Unwrwap to return %q, got %q", expectedCause, err.Unwrap())
}
}

func TestInvalidSyntaxError(t *testing.T) {
err := &InvalidSyntaxError{
specifiedFlag: "--=",
}

if err.GetSpecifiedFlag() != "--=" {
t.Errorf("Expected GetSpecifiedFlag to return %q, got %q", "--=", err.GetSpecifiedFlag())
}
}
52 changes: 30 additions & 22 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag {
func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) {
flag := f.Lookup(name)
if flag == nil {
err := fmt.Errorf("flag accessed but not defined: %s", name)
err := &NotExistError{name: name, messageType: flagNotDefinedMessage}
return nil, err
}

Expand Down Expand Up @@ -411,7 +411,7 @@ func (f *FlagSet) ArgsLenAtDash() int {
func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
flag := f.Lookup(name)
if flag == nil {
return fmt.Errorf("flag %q does not exist", name)
return &NotExistError{name: name, messageType: flagNotExistMessage}
}
if usageMessage == "" {
return fmt.Errorf("deprecated message for flag %q must be set", name)
Expand All @@ -427,7 +427,7 @@ func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) error {
flag := f.Lookup(name)
if flag == nil {
return fmt.Errorf("flag %q does not exist", name)
return &NotExistError{name: name, messageType: flagNotExistMessage}
}
if usageMessage == "" {
return fmt.Errorf("deprecated message for flag %q must be set", name)
Expand All @@ -441,7 +441,7 @@ func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) erro
func (f *FlagSet) MarkHidden(name string) error {
flag := f.Lookup(name)
if flag == nil {
return fmt.Errorf("flag %q does not exist", name)
return &NotExistError{name: name, messageType: flagNotExistMessage}
}
flag.Hidden = true
return nil
Expand All @@ -464,18 +464,16 @@ func (f *FlagSet) Set(name, value string) error {
normalName := f.normalizeFlagName(name)
flag, ok := f.formal[normalName]
if !ok {
return fmt.Errorf("no such flag -%v", name)
return &NotExistError{name: name, messageType: flagNoSuchFlagMessage}
}

err := flag.Value.Set(value)
if err != nil {
var flagName string
if flag.Shorthand != "" && flag.ShorthandDeprecated == "" {
flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name)
} else {
flagName = fmt.Sprintf("--%s", flag.Name)
return &InvalidValueError{
flag: flag,
value: value,
cause: err,
}
return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err)
}

if !flag.Changed {
Expand All @@ -501,7 +499,7 @@ func (f *FlagSet) SetAnnotation(name, key string, values []string) error {
normalName := f.normalizeFlagName(name)
flag, ok := f.formal[normalName]
if !ok {
return fmt.Errorf("no such flag -%v", name)
return &NotExistError{name: name, messageType: flagNoSuchFlagMessage}
}
if flag.Annotations == nil {
flag.Annotations = map[string][]string{}
Expand Down Expand Up @@ -911,10 +909,9 @@ func VarP(value Value, name, shorthand, usage string) {
CommandLine.VarP(value, name, shorthand, usage)
}

// failf prints to standard error a formatted error and usage message and
// fail prints an error message and usage message to standard error and
// returns the error.
func (f *FlagSet) failf(format string, a ...interface{}) error {
err := fmt.Errorf(format, a...)
func (f *FlagSet) fail(err error) error {
if f.errorHandling != ContinueOnError {
fmt.Fprintln(f.Output(), err)
f.usage()
Expand Down Expand Up @@ -960,7 +957,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
a = args
name := s[2:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' {
err = f.failf("bad flag syntax: %s", s)
err = f.fail(&InvalidSyntaxError{specifiedFlag: s})
return
}

Expand All @@ -982,7 +979,7 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin

return stripUnknownFlagValue(a), nil
default:
err = f.failf("unknown flag: --%s", name)
err = f.fail(&NotExistError{name: name, messageType: flagUnknownFlagMessage})
return
}
}
Expand All @@ -1000,13 +997,16 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
a = a[1:]
} else {
// '--flag' (arg was required)
err = f.failf("flag needs an argument: %s", s)
err = f.fail(&ValueRequiredError{
flag: flag,
specifiedName: name,
})
return
}

err = fn(flag, value)
if err != nil {
f.failf(err.Error())
f.fail(err)
}
return
}
Expand Down Expand Up @@ -1039,7 +1039,11 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
outArgs = stripUnknownFlagValue(outArgs)
return
default:
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
err = f.fail(&NotExistError{
name: string(c),
specifiedShorthands: shorthands,
messageType: flagUnknownShorthandFlagMessage,
})
return
}
}
Expand All @@ -1062,7 +1066,11 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
outArgs = args[1:]
} else {
// '-f' (arg was required)
err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
err = f.fail(&ValueRequiredError{
flag: flag,
specifiedName: string(c),
specifiedShorthands: shorthands,
})
return
}

Expand All @@ -1072,7 +1080,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

err = fn(flag, value)
if err != nil {
f.failf(err.Error())
f.fail(err)
}
return
}
Expand Down
Loading
Loading