Skip to content

Commit 1de06c0

Browse files
committed
add CopyToGoFlagSet
This is useful for programs which want to define some flags with pflag (for example, in external packages) but still need to use Go flag command line parsing to preserve backward compatibility with previous releases, in particular support for single-dash flags.
1 parent c78f730 commit 1de06c0

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

golangflag.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
goflag "flag"
99
"reflect"
1010
"strings"
11+
"time"
1112
)
1213

1314
// go test flags prefixes
@@ -113,6 +114,38 @@ func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) {
113114
f.addedGoFlagSets = append(f.addedGoFlagSets, newSet)
114115
}
115116

117+
// CopyToGoFlagSet will add all current flags to the given Go flag set.
118+
// Deprecation remarks get copied into the usage description.
119+
// Whenever possible, a flag gets added for which Go flags shows
120+
// a proper type in the help message.
121+
func (f *FlagSet) CopyToGoFlagSet(newSet *goflag.FlagSet) {
122+
f.VisitAll(func(flag *Flag) {
123+
usage := flag.Usage
124+
if flag.Deprecated != "" {
125+
usage += " (DEPRECATED: " + flag.Deprecated + ")"
126+
}
127+
128+
switch value := flag.Value.(type) {
129+
case *stringValue:
130+
newSet.StringVar((*string)(value), flag.Name, flag.DefValue, usage)
131+
case *intValue:
132+
newSet.IntVar((*int)(value), flag.Name, *(*int)(value), usage)
133+
case *int64Value:
134+
newSet.Int64Var((*int64)(value), flag.Name, *(*int64)(value), usage)
135+
case *uintValue:
136+
newSet.UintVar((*uint)(value), flag.Name, *(*uint)(value), usage)
137+
case *uint64Value:
138+
newSet.Uint64Var((*uint64)(value), flag.Name, *(*uint64)(value), usage)
139+
case *durationValue:
140+
newSet.DurationVar((*time.Duration)(value), flag.Name, *(*time.Duration)(value), usage)
141+
case *float64Value:
142+
newSet.Float64Var((*float64)(value), flag.Name, *(*float64)(value), usage)
143+
default:
144+
newSet.Var(flag.Value, flag.Name, usage)
145+
}
146+
})
147+
}
148+
116149
// ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(),
117150
// since by default those are skipped by pflag.Parse().
118151
// Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)`
@@ -125,3 +158,4 @@ func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error {
125158
}
126159
return goFlagSet.Parse(skippedFlags)
127160
}
161+

golangflag_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package pflag
77
import (
88
goflag "flag"
99
"testing"
10+
"time"
1011
)
1112

1213
func TestGoflags(t *testing.T) {
@@ -59,3 +60,76 @@ func TestGoflags(t *testing.T) {
5960
t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called")
6061
}
6162
}
63+
64+
func TestToGoflags(t *testing.T) {
65+
pfs := FlagSet{}
66+
gfs := goflag.FlagSet{}
67+
pfs.String("StringFlag", "String value", "String flag usage")
68+
pfs.Int("IntFlag", 1, "Int flag usage")
69+
pfs.Uint("UintFlag", 2, "Uint flag usage")
70+
pfs.Int64("Int64Flag", 3, "Int64 flag usage")
71+
pfs.Uint64("Uint64Flag", 4, "Uint64 flag usage")
72+
pfs.Int8("Int8Flag", 5, "Int8 flag usage")
73+
pfs.Float64("Float64Flag", 6.0, "Float64 flag usage")
74+
pfs.Duration("DurationFlag", time.Second, "Duration flag usage")
75+
pfs.Bool("BoolFlag", true, "Bool flag usage")
76+
pfs.String("deprecated", "Deprecated value", "Deprecated flag usage")
77+
pfs.MarkDeprecated("deprecated", "obsolete")
78+
79+
pfs.CopyToGoFlagSet(&gfs)
80+
81+
// Modify via pfs. Should be visible via gfs because both share the
82+
// same values.
83+
for name, value := range map[string]string{
84+
"StringFlag": "Modified String value",
85+
"IntFlag": "11",
86+
"UintFlag": "12",
87+
"Int64Flag": "13",
88+
"Uint64Flag": "14",
89+
"Int8Flag": "15",
90+
"Float64Flag": "16.0",
91+
"BoolFlag": "false",
92+
} {
93+
pf := pfs.Lookup(name)
94+
if pf == nil {
95+
t.Errorf("%s: not found in pflag flag set", name)
96+
continue
97+
}
98+
if err := pf.Value.Set(value); err != nil {
99+
t.Errorf("error setting %s = %s: %v", name, value, err)
100+
}
101+
}
102+
103+
// Check that all flags were added and share the same value.
104+
pfs.VisitAll(func(pf *Flag) {
105+
gf := gfs.Lookup(pf.Name)
106+
if gf == nil {
107+
t.Errorf("%s: not found in Go flag set", pf.Name)
108+
return
109+
}
110+
if gf.Value.String() != pf.Value.String() {
111+
t.Errorf("%s: expected value %v from Go flag set, got %v",
112+
pf.Name, pf.Value, gf.Value)
113+
return
114+
}
115+
})
116+
117+
// Check for unexpected additional flags.
118+
gfs.VisitAll(func(gf *goflag.Flag) {
119+
pf := gfs.Lookup(gf.Name)
120+
if pf == nil {
121+
t.Errorf("%s: not found in pflag flag set", gf.Name)
122+
return
123+
}
124+
})
125+
126+
deprecated := gfs.Lookup("deprecated")
127+
if deprecated == nil {
128+
t.Error("deprecated: not found in Go flag set")
129+
} else {
130+
expectedUsage := "Deprecated flag usage (DEPRECATED: obsolete)"
131+
if deprecated.Usage != expectedUsage {
132+
t.Errorf("deprecation remark not added, expected usage %q, got %q", expectedUsage, deprecated.Usage)
133+
}
134+
}
135+
}

0 commit comments

Comments
 (0)