diff --git a/confmap/confmap.go b/confmap/confmap.go index 9b95e4cf1bc..3040778dcd9 100644 --- a/confmap/confmap.go +++ b/confmap/confmap.go @@ -19,6 +19,7 @@ import ( "github.com/knadh/koanf/v2" encoder "go.opentelemetry.io/collector/confmap/internal/mapstructure" + "go.opentelemetry.io/collector/confmap/internal/third_party/composehook" ) const ( @@ -234,7 +235,7 @@ func decodeConfig(m *Conf, result any, errorUnused bool, skipTopLevelUnmarshaler TagName: MapstructureTag, WeaklyTypedInput: false, MatchName: caseSensitiveMatchName, - DecodeHook: mapstructure.ComposeDecodeHookFunc( + DecodeHook: composehook.ComposeDecodeHookFunc( useExpandValue(), expandNilStructPointersHookFunc(), mapstructure.StringToSliceHookFunc(","), @@ -306,6 +307,23 @@ func isStringyStructure(t reflect.Type) bool { return false } +// safeWrapDecodeHookFunc wraps a DecodeHookFuncValue to ensure fromVal is a valid `reflect.Value` +// object and therefore it is safe to call `reflect.Value` methods on fromVal. +// +// Use this only if the hook does not need to be called on untyped nil values. +// Typed nil values are safe to call and will be passed to the hook. +// See https://github.com/golang/go/issues/51649 +func safeWrapDecodeHookFunc( + f mapstructure.DecodeHookFuncValue, +) mapstructure.DecodeHookFuncValue { + return func(fromVal reflect.Value, toVal reflect.Value) (any, error) { + if !fromVal.IsValid() { + return nil, nil + } + return f(fromVal, toVal) + } +} + // When a value has been loaded from an external source via a provider, we keep both the // parsed value and the original string value. This allows us to expand the value to its // original string representation when decoding into a string field, and use the original otherwise. @@ -355,7 +373,7 @@ func useExpandValue() mapstructure.DecodeHookFuncType { // we want an unmarshaled Config to be equivalent to // Config{Thing: &SomeStruct{}} instead of Config{Thing: nil} func expandNilStructPointersHookFunc() mapstructure.DecodeHookFuncValue { - return func(from reflect.Value, to reflect.Value) (any, error) { + return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) { // ensure we are dealing with map to map comparison if from.Kind() == reflect.Map && to.Kind() == reflect.Map { toElem := to.Type().Elem() @@ -375,7 +393,7 @@ func expandNilStructPointersHookFunc() mapstructure.DecodeHookFuncValue { } } return from.Interface(), nil - } + }) } // mapKeyStringToMapKeyTextUnmarshalerHookFunc returns a DecodeHookFuncType that checks that a conversion from @@ -422,7 +440,7 @@ func mapKeyStringToMapKeyTextUnmarshalerHookFunc() mapstructure.DecodeHookFuncTy // unmarshalerEmbeddedStructsHookFunc provides a mechanism for embedded structs to define their own unmarshal logic, // by implementing the Unmarshaler interface. func unmarshalerEmbeddedStructsHookFunc() mapstructure.DecodeHookFuncValue { - return func(from reflect.Value, to reflect.Value) (any, error) { + return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) { if to.Type().Kind() != reflect.Struct { return from.Interface(), nil } @@ -455,14 +473,14 @@ func unmarshalerEmbeddedStructsHookFunc() mapstructure.DecodeHookFuncValue { } } return fromAsMap, nil - } + }) } // Provides a mechanism for individual structs to define their own unmarshal logic, // by implementing the Unmarshaler interface, unless skipTopLevelUnmarshaler is // true and the struct matches the top level object being unmarshaled. func unmarshalerHookFunc(result any, skipTopLevelUnmarshaler bool) mapstructure.DecodeHookFuncValue { - return func(from reflect.Value, to reflect.Value) (any, error) { + return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) { if !to.CanAddr() { return from.Interface(), nil } @@ -495,14 +513,14 @@ func unmarshalerHookFunc(result any, skipTopLevelUnmarshaler bool) mapstructure. } return unmarshaler, nil - } + }) } // marshalerHookFunc returns a DecodeHookFuncValue that checks structs that aren't // the original to see if they implement the Marshaler interface. func marshalerHookFunc(orig any) mapstructure.DecodeHookFuncValue { origType := reflect.TypeOf(orig) - return func(from reflect.Value, _ reflect.Value) (any, error) { + return safeWrapDecodeHookFunc(func(from reflect.Value, _ reflect.Value) (any, error) { if from.Kind() != reflect.Struct { return from.Interface(), nil } @@ -520,7 +538,7 @@ func marshalerHookFunc(orig any) mapstructure.DecodeHookFuncValue { return nil, err } return conf.ToStringMap(), nil - } + }) } // Unmarshaler interface may be implemented by types to customize their behavior when being unmarshaled from a Conf. @@ -562,7 +580,7 @@ type Marshaler interface { // 4. configuration have no `keys` field specified, the output should be default config // - for example, input is {}, then output is Config{ Keys: ["a", "b"]} func zeroSliceHookFunc() mapstructure.DecodeHookFuncValue { - return func(from reflect.Value, to reflect.Value) (any, error) { + return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) { if to.CanSet() && to.Kind() == reflect.Slice && from.Kind() == reflect.Slice { if from.IsNil() { // input slice is nil, set output slice to nil. @@ -574,7 +592,7 @@ func zeroSliceHookFunc() mapstructure.DecodeHookFuncValue { } return from.Interface(), nil - } + }) } type moduleFactory[T any, S any] interface { diff --git a/confmap/internal/third_party/composehook/compose_hook.go b/confmap/internal/third_party/composehook/compose_hook.go new file mode 100644 index 00000000000..f51050f66ed --- /dev/null +++ b/confmap/internal/third_party/composehook/compose_hook.go @@ -0,0 +1,103 @@ +// Copyright (c) 2013 Mitchell Hashimoto +// SPDX-License-Identifier: MIT +// This code is a modified version of https://github.com/go-viper/mapstructure + +package composehook // import "go.opentelemetry.io/collector/confmap/internal/third_party/composehook" + +import ( + "errors" + "reflect" + + "github.com/go-viper/mapstructure/v2" +) + +// typedDecodeHook takes a raw DecodeHookFunc (an any) and turns +// it into the proper DecodeHookFunc type, such as DecodeHookFuncType. +func typedDecodeHook(h mapstructure.DecodeHookFunc) mapstructure.DecodeHookFunc { + // Create variables here so we can reference them with the reflect pkg + var f1 mapstructure.DecodeHookFuncType + var f2 mapstructure.DecodeHookFuncKind + var f3 mapstructure.DecodeHookFuncValue + + // Fill in the variables into this interface and the rest is done + // automatically using the reflect package. + potential := []any{f3, f1, f2} + + v := reflect.ValueOf(h) + vt := v.Type() + for _, raw := range potential { + pt := reflect.ValueOf(raw).Type() + if vt.ConvertibleTo(pt) { + return v.Convert(pt).Interface() + } + } + + return nil +} + +// cachedDecodeHook takes a raw DecodeHookFunc (an any) and turns +// it into a closure to be used directly +// if the type fails to convert we return a closure always erroring to keep the previous behavior +func cachedDecodeHook(raw mapstructure.DecodeHookFunc) func(reflect.Value, reflect.Value) (any, error) { + switch f := typedDecodeHook(raw).(type) { + case mapstructure.DecodeHookFuncType: + return func(from reflect.Value, to reflect.Value) (any, error) { + // CHANGE FROM UPSTREAM: check if from is valid and return nil if not + if !from.IsValid() { + return nil, nil + } + return f(from.Type(), to.Type(), from.Interface()) + } + case mapstructure.DecodeHookFuncKind: + return func(from reflect.Value, to reflect.Value) (any, error) { + // CHANGE FROM UPSTREAM: check if from is valid and return nil if not + if !from.IsValid() { + return nil, nil + } + return f(from.Kind(), to.Kind(), from.Interface()) + } + case mapstructure.DecodeHookFuncValue: + return func(from reflect.Value, to reflect.Value) (any, error) { + return f(from, to) + } + default: + return func(reflect.Value, reflect.Value) (any, error) { + return nil, errors.New("invalid decode hook signature") + } + } +} + +// ComposeDecodeHookFunc creates a single DecodeHookFunc that +// automatically composes multiple DecodeHookFuncs. +// +// The composed funcs are called in order, with the result of the +// previous transformation. +// +// This is a copy of [mapstructure.ComposeDecodeHookFunc] but with +// validation added. +func ComposeDecodeHookFunc(fs ...mapstructure.DecodeHookFunc) mapstructure.DecodeHookFunc { + cached := make([]func(reflect.Value, reflect.Value) (any, error), 0, len(fs)) + for _, f := range fs { + cached = append(cached, cachedDecodeHook(f)) + } + return func(f reflect.Value, t reflect.Value) (any, error) { + var err error + + // CHANGE FROM UPSTREAM: check if f is valid before calling f.Interface() + var data any + if f.IsValid() { + data = f.Interface() + } + + newFrom := f + for _, c := range cached { + data, err = c(newFrom, t) + if err != nil { + return nil, err + } + newFrom = reflect.ValueOf(data) + } + + return data, nil + } +}