Skip to content
Merged
40 changes: 29 additions & 11 deletions confmap/confmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/v2"

"go.opentelemetry.io/collector/confmap/internal"
encoder "go.opentelemetry.io/collector/confmap/internal/mapstructure"
)

Expand Down Expand Up @@ -234,7 +235,7 @@
TagName: MapstructureTag,
WeaklyTypedInput: false,
MatchName: caseSensitiveMatchName,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
DecodeHook: internal.ComposeDecodeHookFunc(
useExpandValue(),
expandNilStructPointersHookFunc(),
mapstructure.StringToSliceHookFunc(","),
Expand Down Expand Up @@ -306,6 +307,23 @@
return false
}

// safeWrapDecodeHookFunc wraps a DecodeHookFuncValue to ensure that it is safe to call all
// methods in fromVal.
//
// Use this only if the hook does not need to be called on untyped nil values.
// Typed nil values are safe to call and will be passed to the hook.
// See https://github.com/golang/go/issues/51649
func safeWrapDecodeHookFunc(
f mapstructure.DecodeHookFuncValue,
) mapstructure.DecodeHookFuncValue {
return func(fromVal reflect.Value, toVal reflect.Value) (any, error) {
if !fromVal.IsValid() {
return nil, nil
}

Check warning on line 322 in confmap/confmap.go

View check run for this annotation

Codecov / codecov/patch

confmap/confmap.go#L321-L322

Added lines #L321 - L322 were not covered by tests
return f(fromVal, toVal)
}
}

// When a value has been loaded from an external source via a provider, we keep both the
// parsed value and the original string value. This allows us to expand the value to its
// original string representation when decoding into a string field, and use the original otherwise.
Expand Down Expand Up @@ -355,7 +373,7 @@
// we want an unmarshaled Config to be equivalent to
// Config{Thing: &SomeStruct{}} instead of Config{Thing: nil}
func expandNilStructPointersHookFunc() mapstructure.DecodeHookFuncValue {
return func(from reflect.Value, to reflect.Value) (any, error) {
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
// ensure we are dealing with map to map comparison
if from.Kind() == reflect.Map && to.Kind() == reflect.Map {
toElem := to.Type().Elem()
Expand All @@ -375,7 +393,7 @@
}
}
return from.Interface(), nil
}
})
}

// mapKeyStringToMapKeyTextUnmarshalerHookFunc returns a DecodeHookFuncType that checks that a conversion from
Expand Down Expand Up @@ -422,7 +440,7 @@
// unmarshalerEmbeddedStructsHookFunc provides a mechanism for embedded structs to define their own unmarshal logic,
// by implementing the Unmarshaler interface.
func unmarshalerEmbeddedStructsHookFunc() mapstructure.DecodeHookFuncValue {
return func(from reflect.Value, to reflect.Value) (any, error) {
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
if to.Type().Kind() != reflect.Struct {
return from.Interface(), nil
}
Expand Down Expand Up @@ -455,14 +473,14 @@
}
}
return fromAsMap, nil
}
})
}

// Provides a mechanism for individual structs to define their own unmarshal logic,
// by implementing the Unmarshaler interface, unless skipTopLevelUnmarshaler is
// true and the struct matches the top level object being unmarshaled.
func unmarshalerHookFunc(result any, skipTopLevelUnmarshaler bool) mapstructure.DecodeHookFuncValue {
return func(from reflect.Value, to reflect.Value) (any, error) {
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
if !to.CanAddr() {
return from.Interface(), nil
}
Expand Down Expand Up @@ -495,14 +513,14 @@
}

return unmarshaler, nil
}
})
}

// marshalerHookFunc returns a DecodeHookFuncValue that checks structs that aren't
// the original to see if they implement the Marshaler interface.
func marshalerHookFunc(orig any) mapstructure.DecodeHookFuncValue {
origType := reflect.TypeOf(orig)
return func(from reflect.Value, _ reflect.Value) (any, error) {
return safeWrapDecodeHookFunc(func(from reflect.Value, _ reflect.Value) (any, error) {
if from.Kind() != reflect.Struct {
return from.Interface(), nil
}
Expand All @@ -520,7 +538,7 @@
return nil, err
}
return conf.ToStringMap(), nil
}
})
}

// Unmarshaler interface may be implemented by types to customize their behavior when being unmarshaled from a Conf.
Expand Down Expand Up @@ -562,7 +580,7 @@
// 4. configuration have no `keys` field specified, the output should be default config
// - for example, input is {}, then output is Config{ Keys: ["a", "b"]}
func zeroSliceHookFunc() mapstructure.DecodeHookFuncValue {
return func(from reflect.Value, to reflect.Value) (any, error) {
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
if to.CanSet() && to.Kind() == reflect.Slice && from.Kind() == reflect.Slice {
if from.IsNil() {
// input slice is nil, set output slice to nil.
Expand All @@ -574,7 +592,7 @@
}

return from.Interface(), nil
}
})
}

type moduleFactory[T any, S any] interface {
Expand Down
102 changes: 102 additions & 0 deletions confmap/internal/compose_hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package internal // import "go.opentelemetry.io/collector/confmap/internal"

import (
"errors"
"reflect"

"github.com/go-viper/mapstructure/v2"
)

// typedDecodeHook takes a raw DecodeHookFunc (an any) and turns
// it into the proper DecodeHookFunc type, such as DecodeHookFuncType.
func typedDecodeHook(h mapstructure.DecodeHookFunc) mapstructure.DecodeHookFunc {
// Create variables here so we can reference them with the reflect pkg
var f1 mapstructure.DecodeHookFuncType
var f2 mapstructure.DecodeHookFuncKind
var f3 mapstructure.DecodeHookFuncValue

// Fill in the variables into this interface and the rest is done
// automatically using the reflect package.
potential := []any{f3, f1, f2}

v := reflect.ValueOf(h)
vt := v.Type()
for _, raw := range potential {
pt := reflect.ValueOf(raw).Type()
if vt.ConvertibleTo(pt) {
return v.Convert(pt).Interface()
}
}

return nil

Check warning on line 34 in confmap/internal/compose_hook.go

View check run for this annotation

Codecov / codecov/patch

confmap/internal/compose_hook.go#L34

Added line #L34 was not covered by tests
}

// cachedDecodeHook takes a raw DecodeHookFunc (an any) and turns
// it into a closure to be used directly
// if the type fails to convert we return a closure always erroring to keep the previous behavior
func cachedDecodeHook(raw mapstructure.DecodeHookFunc) func(reflect.Value, reflect.Value) (any, error) {
switch f := typedDecodeHook(raw).(type) {
case mapstructure.DecodeHookFuncType:
return func(from reflect.Value, to reflect.Value) (any, error) {
// CHANGE FROM UPSTREAM: check if from is valid and return nil if not
if !from.IsValid() {
return nil, nil
}

Check warning on line 47 in confmap/internal/compose_hook.go

View check run for this annotation

Codecov / codecov/patch

confmap/internal/compose_hook.go#L46-L47

Added lines #L46 - L47 were not covered by tests
return f(from.Type(), to.Type(), from.Interface())
}
case mapstructure.DecodeHookFuncKind:
return func(from reflect.Value, to reflect.Value) (any, error) {
// CHANGE FROM UPSTREAM: check if from is valid and return nil if not
if !from.IsValid() {
return nil, nil
}
return f(from.Kind(), to.Kind(), from.Interface())

Check warning on line 56 in confmap/internal/compose_hook.go

View check run for this annotation

Codecov / codecov/patch

confmap/internal/compose_hook.go#L50-L56

Added lines #L50 - L56 were not covered by tests
}
case mapstructure.DecodeHookFuncValue:
return func(from reflect.Value, to reflect.Value) (any, error) {
return f(from, to)
}
default:
return func(reflect.Value, reflect.Value) (any, error) {
return nil, errors.New("invalid decode hook signature")
}

Check warning on line 65 in confmap/internal/compose_hook.go

View check run for this annotation

Codecov / codecov/patch

confmap/internal/compose_hook.go#L62-L65

Added lines #L62 - L65 were not covered by tests
}
}

// ComposeDecodeHookFunc creates a single DecodeHookFunc that
// automatically composes multiple DecodeHookFuncs.
//
// The composed funcs are called in order, with the result of the
// previous transformation.
//
// This is a copy of [mapstructure.ComposeDecodeHookFunc] but with
// validation added.
func ComposeDecodeHookFunc(fs ...mapstructure.DecodeHookFunc) mapstructure.DecodeHookFunc {
cached := make([]func(reflect.Value, reflect.Value) (any, error), 0, len(fs))
for _, f := range fs {
cached = append(cached, cachedDecodeHook(f))
}
return func(f reflect.Value, t reflect.Value) (any, error) {
var err error

// CHANGE FROM UPSTREAM: check if f is valid before calling f.Interface()
var data any
if f.IsValid() {
data = f.Interface()
}

newFrom := f
for _, c := range cached {
data, err = c(newFrom, t)
if err != nil {
return nil, err
}
newFrom = reflect.ValueOf(data)
}

return data, nil
}
}
Loading