Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions golangflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
goflag "flag"
"reflect"
"strings"
"time"
)

// go test flags prefixes
Expand Down Expand Up @@ -113,6 +114,38 @@ func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) {
f.addedGoFlagSets = append(f.addedGoFlagSets, newSet)
}

// CopyToGoFlagSet will add all current flags to the given Go flag set.
// Deprecation remarks get copied into the usage description.
// Whenever possible, a flag gets added for which Go flags shows
// a proper type in the help message.
func (f *FlagSet) CopyToGoFlagSet(newSet *goflag.FlagSet) {
f.VisitAll(func(flag *Flag) {
usage := flag.Usage
if flag.Deprecated != "" {
usage += " (DEPRECATED: " + flag.Deprecated + ")"
}

switch value := flag.Value.(type) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a bit fragile; how do we ensure it keeps working when new value types are added to the project (as has been done a few times since this PR was opened, e.g. #348, and will probably be done again, e.g. #359)? Is there a better way to register these things with flag without having to know the exact type of the thing?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This list only needs to include types which have a counterpart in the standard flag package. Everything else is handled through the generic default case. In other words, the list is complete until the standard package gets extended, and even then it's not breaking.

case *stringValue:
newSet.StringVar((*string)(value), flag.Name, flag.DefValue, usage)
case *intValue:
newSet.IntVar((*int)(value), flag.Name, *(*int)(value), usage)
case *int64Value:
newSet.Int64Var((*int64)(value), flag.Name, *(*int64)(value), usage)
case *uintValue:
newSet.UintVar((*uint)(value), flag.Name, *(*uint)(value), usage)
case *uint64Value:
newSet.Uint64Var((*uint64)(value), flag.Name, *(*uint64)(value), usage)
case *durationValue:
newSet.DurationVar((*time.Duration)(value), flag.Name, *(*time.Duration)(value), usage)
case *float64Value:
newSet.Float64Var((*float64)(value), flag.Name, *(*float64)(value), usage)
default:
newSet.Var(flag.Value, flag.Name, usage)
}
})
}

// ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(),
// since by default those are skipped by pflag.Parse().
// Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)`
Expand All @@ -125,3 +158,4 @@ func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error {
}
return goFlagSet.Parse(skippedFlags)
}

74 changes: 74 additions & 0 deletions golangflag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package pflag
import (
goflag "flag"
"testing"
"time"
)

func TestGoflags(t *testing.T) {
Expand Down Expand Up @@ -59,3 +60,76 @@ func TestGoflags(t *testing.T) {
t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called")
}
}

func TestToGoflags(t *testing.T) {
pfs := FlagSet{}
gfs := goflag.FlagSet{}
pfs.String("StringFlag", "String value", "String flag usage")
pfs.Int("IntFlag", 1, "Int flag usage")
pfs.Uint("UintFlag", 2, "Uint flag usage")
pfs.Int64("Int64Flag", 3, "Int64 flag usage")
pfs.Uint64("Uint64Flag", 4, "Uint64 flag usage")
pfs.Int8("Int8Flag", 5, "Int8 flag usage")
pfs.Float64("Float64Flag", 6.0, "Float64 flag usage")
pfs.Duration("DurationFlag", time.Second, "Duration flag usage")
pfs.Bool("BoolFlag", true, "Bool flag usage")
pfs.String("deprecated", "Deprecated value", "Deprecated flag usage")
pfs.MarkDeprecated("deprecated", "obsolete")

pfs.CopyToGoFlagSet(&gfs)

// Modify via pfs. Should be visible via gfs because both share the
// same values.
for name, value := range map[string]string{
"StringFlag": "Modified String value",
"IntFlag": "11",
"UintFlag": "12",
"Int64Flag": "13",
"Uint64Flag": "14",
"Int8Flag": "15",
"Float64Flag": "16.0",
"BoolFlag": "false",
} {
pf := pfs.Lookup(name)
if pf == nil {
t.Errorf("%s: not found in pflag flag set", name)
continue
}
if err := pf.Value.Set(value); err != nil {
t.Errorf("error setting %s = %s: %v", name, value, err)
}
}

// Check that all flags were added and share the same value.
pfs.VisitAll(func(pf *Flag) {
gf := gfs.Lookup(pf.Name)
if gf == nil {
t.Errorf("%s: not found in Go flag set", pf.Name)
return
}
if gf.Value.String() != pf.Value.String() {
t.Errorf("%s: expected value %v from Go flag set, got %v",
pf.Name, pf.Value, gf.Value)
return
}
})

// Check for unexpected additional flags.
gfs.VisitAll(func(gf *goflag.Flag) {
pf := gfs.Lookup(gf.Name)
if pf == nil {
t.Errorf("%s: not found in pflag flag set", gf.Name)
return
}
})

deprecated := gfs.Lookup("deprecated")
if deprecated == nil {
t.Error("deprecated: not found in Go flag set")
} else {
expectedUsage := "Deprecated flag usage (DEPRECATED: obsolete)"
if deprecated.Usage != expectedUsage {
t.Errorf("deprecation remark not added, expected usage %q, got %q", expectedUsage, deprecated.Usage)
}
}
}
Loading