diff --git a/decode_hooks.go b/decode_hooks.go index 29b3bb56..02b62637 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -183,6 +183,26 @@ func StringToTimeDurationHookFunc() DecodeHookFunc { } } +// StringToTimeLocationHookFunc returns a DecodeHookFunc that converts +// strings to *time.Location. +func StringToTimeLocationHookFunc() DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data any, + ) (any, error) { + if f.Kind() != reflect.String { + return data, nil + } + if t != reflect.TypeOf(time.Local) { + return data, nil + } + d, err := time.LoadLocation(data.(string)) + + return d, wrapTimeParseLocationError(err) + } +} + // StringToURLHookFunc returns a DecodeHookFunc that converts // strings to *url.URL. func StringToURLHookFunc() DecodeHookFunc { diff --git a/decode_hooks_test.go b/decode_hooks_test.go index c4a05cb4..bb94a48e 100644 --- a/decode_hooks_test.go +++ b/decode_hooks_test.go @@ -547,6 +547,35 @@ func TestStringToTimeDurationHookFunc(t *testing.T) { suite.Run(t) } +func TestStringToTimeLocationHookFunc(t *testing.T) { + newYork, _ := time.LoadLocation("America/New_York") + london, _ := time.LoadLocation("Europe/London") + tehran, _ := time.LoadLocation("Asia/Tehran") + shanghai, _ := time.LoadLocation("Asia/Shanghai") + + suite := decodeHookTestSuite[string, *time.Location]{ + fn: StringToTimeLocationHookFunc(), + ok: []decodeHookTestCase[string, *time.Location]{ + {"UTC", time.UTC}, + {"Local", time.Local}, + {"America/New_York", newYork}, + {"Europe/London", london}, + {"Asia/Tehran", tehran}, + {"Asia/Shanghai", shanghai}, + }, + fail: []decodeHookFailureTestCase[string, *time.Location]{ + {"UTC2"}, // Non-existent + {"5s"}, // Duration-like, not a zone + {"Europe\\London"}, // Invalid path separator + {"../etc/passwd"}, // Unsafe path + {"/etc/zoneinfo"}, // Absolute path (rejected by stdlib) + {"Asia\\Tehran"}, // Invalid Windows-style path + }, + } + + suite.Run(t) +} + func TestStringToURLHookFunc(t *testing.T) { httpURL, _ := url.Parse("http://example.com") httpsURL, _ := url.Parse("https://example.com") diff --git a/errors.go b/errors.go index 222439bd..07d31c22 100644 --- a/errors.go +++ b/errors.go @@ -230,3 +230,15 @@ func wrapTimeParseDurationError(err error) error { return err } + +func wrapTimeParseLocationError(err error) error { + if err == nil { + return nil + } + errMsg := err.Error() + if strings.Contains(errMsg, "unknown time zone") || strings.HasPrefix(errMsg, "time: unknown format") { + return fmt.Errorf("invalid time zone format: %w", err) + } + + return err +}