Skip to content

Commit 599be04

Browse files
samiam2013bincyber
andcommitted
Add -typederrors flag for typed enum conversion errors
Co-authored-by: @bincyber <[email protected]>
1 parent 750eb57 commit 599be04

File tree

11 files changed

+253
-59
lines changed

11 files changed

+253
-59
lines changed

.github/workflows/go.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,31 @@ jobs:
2626

2727
- name: Test
2828
run: go test -v ./...
29+
30+
golden-tests:
31+
runs-on: ubuntu-latest
32+
needs: build
33+
steps:
34+
- uses: actions/checkout@v4
35+
36+
- name: Set up Go
37+
uses: actions/setup-go@v4
38+
with:
39+
go-version: '1.22'
40+
41+
- name: Run Golden Tests
42+
run: go test -v -run TestGolden
43+
44+
end-to-end-tests:
45+
runs-on: ubuntu-latest
46+
needs: build
47+
steps:
48+
- uses: actions/checkout@v4
49+
50+
- name: Set up Go
51+
uses: actions/setup-go@v4
52+
with:
53+
go-version: '1.22'
54+
55+
- name: Run End-to-End Tests
56+
run: go test -v -run TestEndToEnd

README.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Flags:
3737
transform each item name by removing a prefix or comma separated list of prefixes. Default: ""
3838
-type string
3939
comma-separated list of type names; must be set
40+
-typederrors
41+
if true, errors from enumerrs/ will be errors.Join()-ed for errors.Is(...) to simplify invalid value handling. Default: false
4042
-values
4143
if true, alternative string values method will be generated. Default: false
4244
-yaml
@@ -70,6 +72,9 @@ When Enumer is applied to a type, it will generate:
7072
the enum conform to the `gopkg.in/yaml.v2.Marshaler` and `gopkg.in/yaml.v2.Unmarshaler` interfaces.
7173
- When the flag `sql` is provided, the methods for implementing the `Scanner` and `Valuer` interfaces.
7274
Useful when storing the enum in a database.
75+
- When the flag `typederrors` is provided, the string conversion functions will return errors wrapped with
76+
`errors.Join()` containing a typed error from the `enumerrs` package. This allows you to use `errors.Is()` to
77+
check for specific enum validation failures.
7378

7479

7580
For example, if we have an enum type called `Pill`,
@@ -200,7 +205,7 @@ For a module-aware repo with `enumer` in the `go.mod` file, generation can be ca
200205
//go:generate go run github.com/dmarkham/enumer -type=YOURTYPE
201206
```
202207

203-
There are four boolean flags: `json`, `text`, `yaml` and `sql`. You can use any combination of them (i.e. `enumer -type=Pill -json -text`),
208+
There are five boolean flags: `json`, `text`, `yaml`, `sql`, and `typederrors`. You can use any combination of them (i.e. `enumer -type=Pill -json -text -typederrors`),
204209

205210
For enum string representation transformation the `transform` and `trimprefix` flags
206211
were added (i.e. `enumer -type=MyType -json -transform=snake`).
@@ -215,6 +220,28 @@ If a prefix is provided via the `addprefix` flag, it will be added to the start
215220

216221
The boolean flag `values` will additionally create an alternative string values method `Values() []string` to fullfill the `EnumValues` interface of [ent](https://entgo.io/docs/schema-fields/#enum-fields).
217222

223+
## Typed Error Handling
224+
225+
When using the `typederrors` flag, you can handle enum validation errors specifically using `errors.Is()`:
226+
227+
```go
228+
import (
229+
"errors"
230+
"github.com/dmarkham/enumer/enumerrs"
231+
)
232+
233+
// This will return a typed error that can be checked
234+
pill, err := PillString("InvalidValue")
235+
if err != nil {
236+
if errors.Is(err, enumerrs.ErrValueInvalid) {
237+
// Handle invalid enum value specifically
238+
fmt.Println("Invalid pill value provided")
239+
}
240+
// The error also contains a descriptive message
241+
fmt.Printf("Error: %v\n", err)
242+
}
243+
```
244+
218245
## Inspiring projects
219246

220247
- [Álvaro López Espinosa](https://github.com/alvaroloes/enumer)

endtoend_test.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
// go command is not available on android
66

7+
//go:build !android
78
// +build !android
89

910
package main
@@ -75,6 +76,7 @@ func TestEndToEnd(t *testing.T) {
7576
// Names are known to be ASCII and long enough.
7677
var typeName string
7778
var transformNameMethod string
79+
var useTypedErrors bool
7880

7981
switch name {
8082
case "transform_snake.go":
@@ -110,18 +112,22 @@ func TestEndToEnd(t *testing.T) {
110112
case "transform_whitespace.go":
111113
typeName = "WhitespaceSeparatedValue"
112114
transformNameMethod = "whitespace"
115+
case "typedErrors.go":
116+
typeName = "TypedErrorsValue"
117+
transformNameMethod = "noop"
118+
useTypedErrors = true
113119
default:
114120
typeName = fmt.Sprintf("%c%s", name[0]+'A'-'a', name[1:len(name)-len(".go")])
115121
transformNameMethod = "noop"
116122
}
117123

118-
stringerCompileAndRun(t, dir, stringer, typeName, name, transformNameMethod)
124+
stringerCompileAndRun(t, dir, stringer, typeName, name, transformNameMethod, useTypedErrors)
119125
}
120126
}
121127

122128
// stringerCompileAndRun runs stringer for the named file and compiles and
123129
// runs the target binary in directory dir. That binary will panic if the String method is incorrect.
124-
func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, transformNameMethod string) {
130+
func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, transformNameMethod string, useTypedErrors bool) {
125131
t.Logf("run: %s %s\n", fileName, typeName)
126132
source := filepath.Join(dir, fileName)
127133
err := copy(source, filepath.Join("testdata", fileName))
@@ -130,7 +136,12 @@ func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, tran
130136
}
131137
stringSource := filepath.Join(dir, typeName+"_string.go")
132138
// Run stringer in temporary directory.
133-
err = run(stringer, "-type", typeName, "-output", stringSource, "-transform", transformNameMethod, source)
139+
args := []string{"-type", typeName, "-output", stringSource, "-transform", transformNameMethod}
140+
if useTypedErrors {
141+
args = append(args, "-typederrors", "-values")
142+
}
143+
args = append(args, source)
144+
err = run(stringer, args...)
134145
if err != nil {
135146
t.Fatal(err)
136147
}

enumer.go

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ package main
22

33
import "fmt"
44

5-
// Arguments to format are:
6-
// [1]: type name
5+
// Arguments to format are: [1]: type name [2]: complete error expression
76
const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name.
87
// Throws an error if the param is not part of the enum.
98
func %[1]sString(s string) (%[1]s, error) {
@@ -14,20 +13,18 @@ func %[1]sString(s string) (%[1]s, error) {
1413
if val, ok := _%[1]sNameToValueMap[strings.ToLower(s)]; ok {
1514
return val, nil
1615
}
17-
return 0, fmt.Errorf("%%s does not belong to %[1]s values", s)
16+
return 0, %[2]s
1817
}
1918
`
2019

21-
// Arguments to format are:
22-
// [1]: type name
20+
// Arguments to format are: [1]: type name
2321
const stringValuesMethod = `// %[1]sValues returns all values of the enum
2422
func %[1]sValues() []%[1]s {
2523
return _%[1]sValues
2624
}
2725
`
2826

29-
// Arguments to format are:
30-
// [1]: type name
27+
// Arguments to format are: [1]: type name
3128
const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum
3229
func %[1]sStrings() []string {
3330
strs := make([]string, len(_%[1]sNames))
@@ -36,8 +33,7 @@ func %[1]sStrings() []string {
3633
}
3734
`
3835

39-
// Arguments to format are:
40-
// [1]: type name
36+
// Arguments to format are: [1]: type name
4137
const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
4238
func (i %[1]s) IsA%[1]s() bool {
4339
for _, v := range _%[1]sValues {
@@ -49,17 +45,15 @@ func (i %[1]s) IsA%[1]s() bool {
4945
}
5046
`
5147

52-
// Arguments to format are:
53-
// [1]: type name
48+
// Arguments to format are: [1]: type name
5449
const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
5550
func (i %[1]s) IsA%[1]s() bool {
5651
_, ok := _%[1]sMap[i]
5752
return ok
5853
}
5954
`
6055

61-
// Arguments to format are:
62-
// [1]: type name
56+
// Arguments to format are: [1]: type name
6357
const altStringValuesMethod = `func (%[1]s) Values() []string {
6458
return %[1]sStrings()
6559
}
@@ -70,7 +64,7 @@ func (g *Generator) buildAltStringValuesMethod(typeName string) {
7064
g.Printf(altStringValuesMethod, typeName)
7165
}
7266

73-
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int) {
67+
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
7468
// At this moment, either "g.declareIndexAndNameVars()" or "g.declareNameVars()" has been called
7569

7670
// Print the slice of values
@@ -89,7 +83,13 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh
8983
g.printNamesSlice(runs, typeName, runsThreshold)
9084

9185
// Print the basic extra methods
92-
g.Printf(stringNameToValueMethod, typeName)
86+
var errorCode string
87+
if useTypedErrors {
88+
errorCode = fmt.Sprintf(`errors.Join(enumerrs.ErrValueInvalid, fmt.Errorf("%%s does not belong to %s values", s))`, typeName)
89+
} else {
90+
errorCode = fmt.Sprintf(`fmt.Errorf("%%s does not belong to %s values", s)`, typeName)
91+
}
92+
g.Printf(stringNameToValueMethod, typeName, errorCode)
9393
g.Printf(stringValuesMethod, typeName)
9494
g.Printf(stringsMethod, typeName)
9595
if len(runs) <= runsThreshold {
@@ -143,8 +143,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho
143143
g.Printf("}\n\n")
144144
}
145145

146-
// Arguments to format are:
147-
// [1]: type name
146+
// Arguments to format are: [1]: type name
148147
const jsonMethods = `
149148
// MarshalJSON implements the json.Marshaler interface for %[1]s
150149
func (i %[1]s) MarshalJSON() ([]byte, error) {
@@ -164,12 +163,13 @@ func (i *%[1]s) UnmarshalJSON(data []byte) error {
164163
}
165164
`
166165

167-
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int) {
166+
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
167+
// For now, just use the standard template
168+
// We rely on the %[1]sString method to provide typed errors when enabled
168169
g.Printf(jsonMethods, typeName)
169170
}
170171

171-
// Arguments to format are:
172-
// [1]: type name
172+
// Arguments to format are: [1]: type name
173173
const textMethods = `
174174
// MarshalText implements the encoding.TextMarshaler interface for %[1]s
175175
func (i %[1]s) MarshalText() ([]byte, error) {
@@ -184,12 +184,13 @@ func (i *%[1]s) UnmarshalText(text []byte) error {
184184
}
185185
`
186186

187-
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int) {
187+
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
188+
// For now, just use the standard template
189+
// We rely on the %[1]sString method to provide typed errors when enabled
188190
g.Printf(textMethods, typeName)
189191
}
190192

191-
// Arguments to format are:
192-
// [1]: type name
193+
// Arguments to format are: [1]: type name
193194
const yamlMethods = `
194195
// MarshalYAML implements a YAML Marshaler for %[1]s
195196
func (i %[1]s) MarshalYAML() (interface{}, error) {
@@ -209,6 +210,8 @@ func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error {
209210
}
210211
`
211212

212-
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) {
213+
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
214+
// For now, just use the standard template
215+
// We rely on the %[1]sString method to provide typed errors when enabled
213216
g.Printf(yamlMethods, typeName)
214217
}

enumerrs/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package enumerrs
2+
3+
import "errors"
4+
5+
// This package defines custom error types for use in the generated code.
6+
7+
// ErrValueInvalid is returned when a value does not belong to the set of valid values for a type.
8+
var ErrValueInvalid = errors.New("the input value is not valid for the type")

0 commit comments

Comments
 (0)