diff --git a/featureflag/config.go b/featureflag/config.go deleted file mode 100644 index 26a601c..0000000 --- a/featureflag/config.go +++ /dev/null @@ -1,24 +0,0 @@ -package featureflag - -import ( - "github.com/netlify/netlify-commons/util" - "gopkg.in/launchdarkly/go-server-sdk.v5/interfaces" -) - -type Config struct { - Key string `json:"key" yaml:"key"` - RequestTimeout util.Duration `json:"request_timeout" yaml:"request_timeout" mapstructure:"request_timeout" split_words:"true" default:"5s"` - Enabled bool `json:"enabled" yaml:"enabled" default:"false"` - - updateProcessorFactory interfaces.DataSourceFactory - - // Drop telemetry events (not needed in local-dev/CI environments) - DisableEvents bool `json:"disable_events" yaml:"disable_events" mapstructure:"disable_events" split_words:"true"` - - // Set when using the Launch Darkly Relay proxy - RelayHost string `json:"relay_host" yaml:"relay_host" mapstructure:"relay_host" split_words:"true"` - - // DefaultUserAttrs are custom LaunchDarkly user attributes that are added to every - // feature flag check - DefaultUserAttrs map[string]string `json:"default_user_attrs" yaml:"default_user_attrs"` -} diff --git a/featureflag/featureflag.go b/featureflag/featureflag.go deleted file mode 100644 index fa9f5d1..0000000 --- a/featureflag/featureflag.go +++ /dev/null @@ -1,183 +0,0 @@ -package featureflag - -import ( - "io/ioutil" - - "github.com/sirupsen/logrus" - - "gopkg.in/launchdarkly/go-sdk-common.v2/ldlog" - "gopkg.in/launchdarkly/go-sdk-common.v2/lduser" - "gopkg.in/launchdarkly/go-sdk-common.v2/ldvalue" - ld "gopkg.in/launchdarkly/go-server-sdk.v5" - "gopkg.in/launchdarkly/go-server-sdk.v5/interfaces" - "gopkg.in/launchdarkly/go-server-sdk.v5/interfaces/flagstate" - "gopkg.in/launchdarkly/go-server-sdk.v5/ldcomponents" -) - -type Client interface { - Enabled(key, userID string, attrs ...Attr) bool - EnabledUser(key string, user lduser.User) bool - - Variation(key, defaultVal, userID string, attrs ...Attr) string - VariationUser(key string, defaultVal string, user lduser.User) string - - Int(key string, defaultVal int, userID string, attrs ...Attr) int - IntUser(key string, defaultVal int, user lduser.User) int - - AllEnabledFlags(key string) []string - AllEnabledFlagsUser(key string, user lduser.User) []string -} - -type ldClient struct { - *ld.LDClient - log logrus.FieldLogger - defaultAttrs []Attr -} - -var _ Client = &ldClient{} - -func NewClient(cfg *Config, logger logrus.FieldLogger) (Client, error) { - config := ld.Config{} - - if !cfg.Enabled { - config.Offline = true - } - - if cfg.updateProcessorFactory != nil { - config.DataSource = cfg.updateProcessorFactory - config.Events = ldcomponents.NoEvents() - } - - config.Logging = configureLogger(logger) - - if cfg.RelayHost != "" { - config.ServiceEndpoints = ldcomponents.RelayProxyEndpoints(cfg.RelayHost) - } - - if cfg.DisableEvents { - config.Events = ldcomponents.NoEvents() - } - - inner, err := ld.MakeCustomClient(cfg.Key, config, cfg.RequestTimeout.Duration) - if err != nil { - logger.WithError(err).Error("Unable to construct LD client") - } - - var defaultAttrs []Attr - for k, v := range cfg.DefaultUserAttrs { - defaultAttrs = append(defaultAttrs, StringAttr(k, v)) - } - return &ldClient{inner, logger, defaultAttrs}, err -} - -func (c *ldClient) Enabled(key string, userID string, attrs ...Attr) bool { - return c.EnabledUser(key, c.userWithAttrs(userID, attrs)) -} - -func (c *ldClient) EnabledUser(key string, user lduser.User) bool { - res, err := c.BoolVariation(key, user, false) - if err != nil { - c.log.WithError(err).WithField("key", key).Error("Failed to load feature flag") - } - return res -} - -func (c *ldClient) Variation(key, defaultVal, userID string, attrs ...Attr) string { - return c.VariationUser(key, defaultVal, c.userWithAttrs(userID, attrs)) -} - -func (c *ldClient) VariationUser(key string, defaultVal string, user lduser.User) string { - res, err := c.StringVariation(key, user, defaultVal) - if err != nil { - c.log.WithError(err).WithField("key", key).Error("Failed to load feature flag") - } - return res -} - -func (c *ldClient) Int(key string, defaultValue int, userID string, attrs ...Attr) int { - return c.IntUser(key, defaultValue, c.userWithAttrs(userID, attrs)) -} - -func (c *ldClient) IntUser(key string, defaultVal int, user lduser.User) int { - res, err := c.IntVariation(key, user, defaultVal) - if err != nil { - c.log.WithError(err).WithField("key", key).Error("Failed to load feature flag") - } - // DefaultValue will be returned if IntVariation returns an error - return res -} - -func (c *ldClient) AllEnabledFlags(key string) []string { - return c.AllEnabledFlagsUser(key, lduser.NewUser(key)) -} - -func (c *ldClient) AllEnabledFlagsUser(key string, user lduser.User) []string { - res := c.AllFlagsState(user, flagstate.OptionDetailsOnlyForTrackedFlags()) - flagMap := res.ToValuesMap() - - var flags []string - for flag, value := range flagMap { - if value.BoolValue() { - flags = append(flags, flag) - } - } - - return flags -} - -func (c *ldClient) userWithAttrs(id string, attrs []Attr) lduser.User { - b := lduser.NewUserBuilder(id) - for _, attr := range c.defaultAttrs { - b.Custom(attr.Name, attr.Value) - } - for _, attr := range attrs { - b.Custom(attr.Name, attr.Value) - } - return b.Build() -} - -type Attr struct { - Name string - Value ldvalue.Value -} - -func StringAttr(name, value string) Attr { - return Attr{Name: name, Value: ldvalue.String(value)} -} - -func configureLogger(log logrus.FieldLogger) interfaces.LoggingConfigurationFactory { - if log == nil { - l := logrus.New() - l.SetOutput(ioutil.Discard) - log = l - } - log = log.WithField("component", "launch_darkly") - - return &logCreator{log: log} -} - -type logCreator struct { - log logrus.FieldLogger -} - -func (c *logCreator) CreateLoggingConfiguration(b interfaces.BasicConfiguration) (interfaces.LoggingConfiguration, error) { - logger := ldlog.NewDefaultLoggers() - logger.SetBaseLoggerForLevel(ldlog.Debug, &wrapLog{c.log.Debugln, c.log.Debugf}) - logger.SetBaseLoggerForLevel(ldlog.Info, &wrapLog{c.log.Infoln, c.log.Infof}) - logger.SetBaseLoggerForLevel(ldlog.Warn, &wrapLog{c.log.Warnln, c.log.Warnf}) - logger.SetBaseLoggerForLevel(ldlog.Error, &wrapLog{c.log.Errorln, c.log.Errorf}) - return ldcomponents.Logging().Loggers(logger).CreateLoggingConfiguration(b) -} - -type wrapLog struct { - println func(values ...interface{}) - printf func(format string, values ...interface{}) -} - -func (l *wrapLog) Println(values ...interface{}) { - l.println(values...) -} - -func (l *wrapLog) Printf(format string, values ...interface{}) { - l.printf(format, values...) -} diff --git a/featureflag/featureflag_test.go b/featureflag/featureflag_test.go deleted file mode 100644 index 0285b8d..0000000 --- a/featureflag/featureflag_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package featureflag - -import ( - "bytes" - "testing" - "time" - - "github.com/netlify/netlify-commons/util" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/launchdarkly/go-server-sdk.v5/ldfiledata" -) - -func TestOfflineClient(t *testing.T) { - cfg := Config{ - Key: "ABCD", - RequestTimeout: util.Duration{time.Second}, - Enabled: false, - } - client, err := NewClient(&cfg, nil) - require.NoError(t, err) - - require.False(t, client.Enabled("notset", "12345")) - require.Equal(t, "foobar", client.Variation("notset", "foobar", "12345")) - require.Equal(t, 3, client.Int("noset", 3, "12345")) -} - -func TestMockClient(t *testing.T) { - mock := MockClient{ - BoolVars: map[string]bool{ - "FOO": true, - "BAR": false, - }, - StringVars: map[string]string{ - "FOO": "BAR", - "BLAH": "FOOBAR", - }, - IntVars: map[string]int{ - "FOO": 4, - "BLAH": 7, - }, - } - - require.True(t, mock.Enabled("FOO", "12345")) - require.False(t, mock.Enabled("BAR", "12345")) - require.False(t, mock.Enabled("NOTSET", "12345")) - - require.Equal(t, "BAR", mock.Variation("FOO", "DFLT", "12345")) - require.Equal(t, "DFLT", mock.Variation("FOOBAR", "DFLT", "12345")) - - require.Equal(t, 4, mock.Int("FOO", 2, "12345")) - require.Equal(t, 2, mock.Int("FOOBAR", 2, "12345")) -} - -func TestAllEnabledFlags(t *testing.T) { - fileSource := ldfiledata.DataSource().FilePaths("./fixtures/flags.yml") - cfg := Config{ - Key: "ABCD", - RequestTimeout: util.Duration{time.Second}, - Enabled: true, - updateProcessorFactory: fileSource, - } - client, err := NewClient(&cfg, nil) - require.NoError(t, err) - - flags := client.AllEnabledFlags("userid") - - require.Equal(t, []string{"my-boolean-flag-key"}, flags) -} - -func TestLogging(t *testing.T) { - cfg := Config{ - Key: "ABCD", - RequestTimeout: util.Duration{time.Second}, - Enabled: false, - } - - logBuf := new(bytes.Buffer) - log := logrus.New() - log.Out = logBuf - - _, err := NewClient(&cfg, log.WithField("component", "launch_darkly")) - require.NoError(t, err) - assert.NotEmpty(t, logBuf.Bytes()) -} diff --git a/featureflag/fixtures/flags.yml b/featureflag/fixtures/flags.yml deleted file mode 100644 index 78259db..0000000 --- a/featureflag/fixtures/flags.yml +++ /dev/null @@ -1,5 +0,0 @@ -flagValues: - my-string-flag-key: "value-1" - my-boolean-flag-key: true - my-boolean-off-flag-key: false - my-integer-flag-key: 3 diff --git a/featureflag/global.go b/featureflag/global.go deleted file mode 100644 index 4eeafdb..0000000 --- a/featureflag/global.go +++ /dev/null @@ -1,47 +0,0 @@ -package featureflag - -import ( - "github.com/sirupsen/logrus" - "go.uber.org/atomic" - "unsafe" -) - -// See https://blog.dubbelboer.com/2015/08/23/rwmutex-vs-atomicvalue-vs-unsafepointer.html -var ( - defaultClient Client = MockClient{} - globalClient = atomic.NewUnsafePointer(unsafe.Pointer(&defaultClient)) -) - -func SetGlobalClient(client Client) { - if client == nil { - return - } - globalClient.Store(unsafe.Pointer(&client)) -} - -func GetGlobalClient() Client { - c := (*Client)(globalClient.Load()) - return *c -} - -// Init will initialize global client with a launch darkly client -func Init(conf Config, log logrus.FieldLogger) error { - ldClient, err := NewClient(&conf, log) - if err != nil { - return err - } - SetGlobalClient(ldClient) - return nil -} - -func Enabled(key, userID string, attrs ...Attr) bool { - return GetGlobalClient().Enabled(key, userID, attrs...) -} - -func Variation(key, defaultVal, userID string, attrs ...Attr) string { - return GetGlobalClient().Variation(key, defaultVal, userID, attrs...) -} - -func Int(key string, defaultVal int, userID string, attrs ...Attr) int { - return GetGlobalClient().Int(key, defaultVal, userID, attrs...) -} diff --git a/featureflag/global_test.go b/featureflag/global_test.go deleted file mode 100644 index e664158..0000000 --- a/featureflag/global_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package featureflag - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestGlobalAccess(t *testing.T) { - // initial value should be default - require.Equal(t, defaultClient, GetGlobalClient()) - - // setting new global should be reflected - n := &ldClient{} - SetGlobalClient(n) - require.Equal(t, n, GetGlobalClient()) -} diff --git a/featureflag/mock.go b/featureflag/mock.go deleted file mode 100644 index 3b4d158..0000000 --- a/featureflag/mock.go +++ /dev/null @@ -1,59 +0,0 @@ -package featureflag - -import ( - "gopkg.in/launchdarkly/go-sdk-common.v2/lduser" -) - -type MockClient struct { - BoolVars map[string]bool - StringVars map[string]string - IntVars map[string]int -} - -var _ Client = MockClient{} - -func (c MockClient) Enabled(key, userID string, _ ...Attr) bool { - return c.EnabledUser(key, lduser.NewUser(userID)) -} - -func (c MockClient) EnabledUser(key string, _ lduser.User) bool { - return c.BoolVars[key] -} - -func (c MockClient) Variation(key string, defaultVal string, userID string, _ ...Attr) string { - return c.VariationUser(key, defaultVal, lduser.NewUser(userID)) -} - -func (c MockClient) VariationUser(key string, defaultVal string, _ lduser.User) string { - res, ok := c.StringVars[key] - if !ok { - return defaultVal - } - return res -} - -func (c MockClient) Int(key string, defaultVal int, userID string, _ ...Attr) int { - return c.IntUser(key, defaultVal, lduser.NewUser(userID)) -} - -func (c MockClient) IntUser(key string, defaultVal int, _ lduser.User) int { - res, ok := c.IntVars[key] - if !ok { - return defaultVal - } - return res -} - -func (c MockClient) AllEnabledFlags(key string) []string { - return c.AllEnabledFlagsUser(key, lduser.NewUser(key)) -} - -func (c MockClient) AllEnabledFlagsUser(key string, _ lduser.User) []string { - var res []string - for key, value := range c.BoolVars { - if value { - res = append(res, key) - } - } - return res -} diff --git a/instrument/instrument_test.go b/instrument/instrument_test.go index 5f80939..21eb7bc 100644 --- a/instrument/instrument_test.go +++ b/instrument/instrument_test.go @@ -4,7 +4,8 @@ import ( "reflect" "testing" - "github.com/netlify/netlify-commons/testutil" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/segmentio/analytics-go.v3" @@ -22,7 +23,7 @@ func TestLogOnlyClient(t *testing.T) { } func TestMockClient(t *testing.T) { - log := testutil.TL(t) + log := logrus.New() mock := MockClient{log} require.NoError(t, mock.Identify("myuser", analytics.NewTraits().SetName("My User"))) @@ -33,7 +34,7 @@ func TestLogging(t *testing.T) { Key: "ABCD", } - log, hook := testutil.TestLogger(t) + log, hook := test.NewNullLogger() client, err := NewClient(&cfg, log.WithField("component", "segment")) require.NoError(t, err) diff --git a/messaging/config.go b/messaging/config.go index d7284f4..a96c9c7 100644 --- a/messaging/config.go +++ b/messaging/config.go @@ -5,14 +5,10 @@ import ( "strings" "time" - "github.com/netlify/netlify-commons/nconf" - - "github.com/pkg/errors" - "github.com/nats-io/stan.go" "github.com/nats-io/stan.go/pb" - "github.com/netlify/netlify-commons/discovery" + "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -30,10 +26,10 @@ type NatsAuth struct { } type NatsConfig struct { - TLS *nconf.TLSConfig `mapstructure:"tls_conf"` - DiscoveryName string `mapstructure:"discovery_name" split_words:"true"` - Servers []string `mapstructure:"servers"` - Auth NatsAuth `mapstructure:"auth"` + TLS *TLSConfig `mapstructure:"tls_conf"` + DiscoveryName string `mapstructure:"discovery_name" split_words:"true"` + Servers []string `mapstructure:"servers"` + Auth NatsAuth `mapstructure:"auth"` // for streaming ClusterID string `mapstructure:"cluster_id" split_words:"true"` diff --git a/nconf/tls.go b/messaging/tls.go similarity index 99% rename from nconf/tls.go rename to messaging/tls.go index 025649e..c017f61 100644 --- a/nconf/tls.go +++ b/messaging/tls.go @@ -1,4 +1,4 @@ -package nconf +package messaging import ( "crypto/tls" diff --git a/metriks/config.go b/metriks/config.go deleted file mode 100644 index e64bb6c..0000000 --- a/metriks/config.go +++ /dev/null @@ -1,9 +0,0 @@ -package metriks - -type Config struct { - Enabled bool - Host string `default:"localhost"` - Port int `default:"8125"` - - Tags map[string]string -} diff --git a/metriks/db_stats.go b/metriks/db_stats.go deleted file mode 100644 index e4f9d40..0000000 --- a/metriks/db_stats.go +++ /dev/null @@ -1,69 +0,0 @@ -package metriks - -import ( - "database/sql" - "fmt" - "time" - - "github.com/armon/go-metrics" - - "github.com/netlify/netlify-commons/util" -) - -type DBStats interface { - Start() - Stop() -} - -type dBStats struct { - db *sql.DB - name string - labels []metrics.Label - sched util.ScheduledExecutor -} - -const defaultTickTime = 2 * time.Second - -// NewDBStats returns a managed object that when Start() is invoked, will periodically -// report the stats from the passed DB object. -// -// Stop() should be called before the DB is closed. -func NewDBStats(db *sql.DB, name string, labels []metrics.Label) DBStats { - dbstats := &dBStats{ - db: db, - name: fmt.Sprintf("dbstats.%s", name), - labels: labels, - } - - dbstats.sched = util.NewScheduledExecutor(defaultTickTime, dbstats.emitStats) - - return dbstats -} - -func (d *dBStats) Start() { - d.sched.Start() -} - -func (d *dBStats) Stop() { - d.sched.Stop() -} - -func (d *dBStats) emitStats() { - stats := d.db.Stats() - - d.emitStat("MaxOpenConnections", float32(stats.MaxOpenConnections)) - - d.emitStat("OpenConnections", float32(stats.OpenConnections)) - d.emitStat("InUse", float32(stats.InUse)) - d.emitStat("Idle", float32(stats.Idle)) - - d.emitStat("WaitCount", float32(stats.WaitCount)) - d.emitStat("WaitDuration_us", float32(stats.WaitDuration.Nanoseconds()/1000)) - d.emitStat("MaxIdleClosed", float32(stats.MaxIdleClosed)) - d.emitStat("MaxLifetimeClosed", float32(stats.MaxLifetimeClosed)) -} - -func (d *dBStats) emitStat(statName string, value float32) { - labels := append(d.labels, L("db_stat", statName)) - SampleLabels(d.name, labels, value) -} diff --git a/metriks/distribution.go b/metriks/distribution.go deleted file mode 100644 index 9792b2a..0000000 --- a/metriks/distribution.go +++ /dev/null @@ -1,50 +0,0 @@ -package metriks - -import ( - "fmt" - "strings" - - "github.com/DataDog/datadog-go/statsd" - "github.com/armon/go-metrics" -) - -var ( - distributionFunc = func(name string, value float64, tags ...metrics.Label) {} - distributionErrorHandler = func(_ error) {} -) - -func initDistribution(url string, serviceName string, permTags []string) error { - statsd, err := statsd.New(url) - if err != nil { - return err - } - - distributionFunc = func(name string, value float64, tags ...metrics.Label) { - ddtags := make([]string, 0, len(permTags)+len(tags)) - ddtags = append(ddtags, permTags...) - for _, v := range tags { - ddtags = append(ddtags, fmt.Sprintf("%s:%s", v.Name, v.Value)) - } - - name = fmt.Sprintf("%s.%s", serviceName, strings.ReplaceAll(name, "-", "_")) - - err := statsd.Distribution(name, float64(value), ddtags, 1) - if err != nil { - distributionErrorHandler(err) - } - } - - return nil -} - -// SetDistributionErrorHandler will set the global error handler. It will be invoked -// anytime that the statsd call produces an error -func SetDistributionErrorHandler(f func(error)) { - distributionErrorHandler = f -} - -// Distribution will report the value as a distribution metric to datadog. -// it only makes sense when you're using the datadog sink -func Distribution(name string, value float64, tags ...metrics.Label) { - distributionFunc(name, value, tags...) -} diff --git a/metriks/distribution_test.go b/metriks/distribution_test.go deleted file mode 100644 index a5ba7ec..0000000 --- a/metriks/distribution_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package metriks - -import ( - "bytes" - "net" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDistribution(t *testing.T) { - pc, err := net.ListenPacket("udp", "") - require.NoError(t, err) - defer pc.Close() - - msgs := make(chan []byte) - go func() { - buf := make([]byte, 1024) - _, _, err := pc.ReadFrom(buf) - require.NoError(t, err) - msgs <- buf - close(msgs) - }() - - require.NoError(t, initDistribution(pc.LocalAddr().String(), "testing", []string{"a:1"})) - SetDistributionErrorHandler(func(err error) { - assert.NoError(t, err) - }) - - Distribution("some_metric", 12.0, L("b", "c")) - - buf := <-msgs - assert.Equal(t, "testing.some_metric:12|d|#a:1,b:c\n", string(bytes.Trim(buf, "\x00"))) -} - -func TestDistributionRace(t *testing.T) { - pc, err := net.ListenPacket("udp", "") - require.NoError(t, err) - defer pc.Close() - - go func() { - for { - buf := make([]byte, 1024) - pc.ReadFrom(buf) - } - }() - - // set cap so concurrent callers of Distribution overwrite the same space - permTags := make([]string, 1, 8) - permTags[0] = "a:1" - - require.NoError(t, initDistribution(pc.LocalAddr().String(), "testing", permTags)) - SetDistributionErrorHandler(func(err error) { - assert.NoError(t, err) - }) - - work := make(chan struct{}) - var wg sync.WaitGroup - for n := 0; n < 1024; n++ { - wg.Add(1) - go func() { - for range work { - Distribution("some_metric", 12.0, L("b", "c")) - } - wg.Done() - }() - } - - for n := 0; n < 100_000; n++ { - work <- struct{}{} - } - close(work) - wg.Wait() -} diff --git a/metriks/gauge.go b/metriks/gauge.go deleted file mode 100644 index ebac1ed..0000000 --- a/metriks/gauge.go +++ /dev/null @@ -1,109 +0,0 @@ -package metriks - -import ( - "context" - "sync/atomic" - "time" - - "github.com/armon/go-metrics" - "github.com/netlify/netlify-commons/util" -) - -const ( - defaultGaugeDuration = time.Second * 5 -) - -// PersistentGauge will report on an interval the value to the metrics collector. -// -type PersistentGauge struct { - name string - value int32 - tags []metrics.Label - - ticker *time.Ticker - cancel context.CancelFunc - dur time.Duration - donec chan struct{} -} - -// Set will replace the value with a new one, it returns the old value -func (g *PersistentGauge) Set(value int32) int32 { - return atomic.SwapInt32(&g.value, value) -} - -// Inc will +1 to the current value and return the new value -func (g *PersistentGauge) Inc() int32 { - return atomic.AddInt32(&g.value, 1) -} - -// Dec will -1 to the current value and return the new value -func (g *PersistentGauge) Dec() int32 { - return atomic.AddInt32(&g.value, -1) -} - -func (g *PersistentGauge) report(v int32) { - Gauge(g.name, float32(v), g.tags...) - g.ticker.Reset(g.dur) -} - -func (g *PersistentGauge) start(ctx context.Context) { - for { - select { - case <-ctx.Done(): - close(g.donec) - return - case <-g.ticker.C: - g.report(atomic.LoadInt32(&g.value)) - } - } -} - -// Stop will make the gauge stop reporting. Any calls to Inc/Set/Dec will still report -// to the metrics collector -func (g *PersistentGauge) Stop() { - g.cancel() - <-g.donec -} - -// NewPersistentGauge will create and start a PersistentGauge that reports the current value every 10s -func NewPersistentGauge(name string, tags ...metrics.Label) *PersistentGauge { - return NewPersistentGaugeWithDuration(name, defaultGaugeDuration, tags...) -} - -// NewPersistentGaugeWithDuration will create and start a PersistentGauge that reports the current value every period -func NewPersistentGaugeWithDuration(name string, dur time.Duration, tags ...metrics.Label) *PersistentGauge { - ctx, cancel := context.WithCancel(context.Background()) - g := PersistentGauge{ - name: name, - tags: tags, - ticker: time.NewTicker(dur), - cancel: cancel, - dur: dur, - donec: make(chan struct{}), - } - go g.start(ctx) - return &g -} - -// ScheduledGauge will call the provided method after a duration -// it will then report that value to the metrics system -type ScheduledGauge struct { - util.ScheduledExecutor -} - -// NewScheduledGauge will create an start a ScheduledGauge that reports the value every 10s -func NewScheduledGauge(name string, cb func() int32, tags ...metrics.Label) ScheduledGauge { - return NewScheduledGaugeWithDuration(name, defaultGaugeDuration, cb, tags...) -} - -// NewScheduledGaugeWithDuration will create an start a ScheduledGauge that reports the value every period -func NewScheduledGaugeWithDuration(name string, dur time.Duration, cb func() int32, tags ...metrics.Label) ScheduledGauge { - g := ScheduledGauge{ - ScheduledExecutor: util.NewScheduledExecutor(dur, func() { - v := cb() - Gauge(name, float32(v), tags...) - }), - } - g.Start() - return g -} diff --git a/metriks/gauge_test.go b/metriks/gauge_test.go deleted file mode 100644 index b3f81b3..0000000 --- a/metriks/gauge_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package metriks - -import ( - "bytes" - "net" - "strings" - "testing" - "time" - - "github.com/armon/go-metrics" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPersistentGauge(t *testing.T) { - g := NewPersistentGaugeWithDuration("some_gauge", time.Second, L("a", "b")) - defer g.Stop() - - res := setupStatsDSink(t) - - assert.EqualValues(t, 1, g.Inc()) - assert.EqualValues(t, 0, g.Dec()) - assert.EqualValues(t, 0, g.Set(10)) - - expectedValues := []string{ - // this value should be reported after an interval - "test.some_gauge.b:10.000000|g", - } - - for i := 0; i < len(expectedValues); i++ { - select { - case msg := <-res: - assert.Equal(t, expectedValues[i], msg) - case <-time.After(time.Second * 10): - assert.Fail(t, "failed to get a metric in time") - } - } -} - -func TestPersistentGaugeRace(t *testing.T) { - for n := 0; n < 10; n++ { - g := NewPersistentGaugeWithDuration("some_gauge", time.Second, L("a", "b")) - g.Stop() - } -} - -func TestScheduledGauge(t *testing.T) { - var callCount int32 - cb := func() int32 { - callCount++ - return callCount - } - - g := NewScheduledGaugeWithDuration("some_gauge", time.Second, cb, L("a", "b")) - defer g.Stop() - - res := setupStatsDSink(t) - - expectedValues := []string{ - "test.some_gauge.b:1.000000|g", - "test.some_gauge.b:2.000000|g", - } - for i := 0; i < 2; i++ { - select { - case msg := <-res: - assert.Equal(t, expectedValues[i], msg) - case <-time.After(time.Second * 20): - require.Fail(t, "failed to get a metric in time") - } - } -} - -func setupStatsDSink(t *testing.T) <-chan string { - pc, err := net.ListenPacket("udp", ":10000") - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, pc.Close()) }) - - sink, err := metrics.NewStatsdSink(pc.LocalAddr().String()) - require.NoError(t, err) - t.Cleanup(sink.Shutdown) - - cfg := metrics.DefaultConfig("test") - cfg.EnableHostname = false - cfg.EnableRuntimeMetrics = false - _, err = metrics.NewGlobal(cfg, sink) - require.NoError(t, err) - - return readValues(t, pc) -} - -func readValues(t *testing.T, pc net.PacketConn) <-chan string { - res := make(chan string) - go func() { - for { - buf := make([]byte, 512) - _, _, err := pc.ReadFrom(buf) - if err != nil { - close(res) - return - } - for _, p := range strings.Split(string(bytes.Trim(buf, "\x00")), "\n") { - if p != "" { - res <- p - } - } - } - }() - return res -} diff --git a/metriks/metriks.go b/metriks/metriks.go deleted file mode 100644 index dd593a2..0000000 --- a/metriks/metriks.go +++ /dev/null @@ -1,231 +0,0 @@ -package metriks - -import ( - "fmt" - "net/url" - "os" - "strings" - "time" - - "github.com/armon/go-metrics" - "github.com/armon/go-metrics/datadog" - "github.com/pkg/errors" -) - -const timerGranularity = time.Millisecond - -// Init will initialize the internal metrics system with a Datadog statsd sink -func Init(serviceName string, conf Config) error { - return InitTags(serviceName, conf, nil) -} - -// InitTags behaves like Init but allows appending extra tags -func InitTags(serviceName string, conf Config, extraTags []string) error { - if !conf.Enabled { - return nil - } - - sink, err := createDatadogSink(conf.StatsdAddr(), "", serviceName, conf.Tags, extraTags) - if err != nil { - return err - } - - return InitWithSink(serviceName, sink) -} - -// InitWithURL will initialize using a URL to identify the sink type -// -// Examples: -// -// InitWithURL("api", "datadog://187.32.21.12:8125/?hostname=foo.com&tag=env:production") -// -// InitWithURL("api", "discard://nothing") -// -// InitWithURL("api", "inmem://discarded/?interval=10s&duration=30s") -// -func InitWithURL(serviceName string, endpoint string) (metrics.MetricSink, error) { - u, err := url.Parse(endpoint) - if err != nil { - return nil, errors.Wrap(err, "invalid endpoint format") - } - - hostname := u.Query().Get("hostname") - if hostname == "" { - h, _ := os.Hostname() - hostname = h - } - - var sink metrics.MetricSink - switch u.Scheme { - case "datadog": - sink, err = createDatadogSink(u.Host, hostname, serviceName, map[string]string{}, u.Query()["tag"]) - case "discard", "": - sink = &metrics.BlackholeSink{} - default: - sink, err = metrics.NewMetricSinkFromURL(endpoint) - } - - if err != nil { - return nil, errors.Wrap(err, "error creating sink") - } - - err = InitWithSink(serviceName, sink) - return sink, err -} - -// InitWithSink initializes the internal metrics system with custom sink -func InitWithSink(serviceName string, sink metrics.MetricSink) error { - c := metrics.DefaultConfig(serviceName) - c.EnableHostname = false - c.EnableHostnameLabel = false - c.EnableServiceLabel = false - c.TimerGranularity = timerGranularity - - if _, err := metrics.NewGlobal(c, sink); err != nil { - return err - } - return nil -} - -func createDatadogSink(url, hostname, serviceName string, tags map[string]string, extraTags []string) (metrics.MetricSink, error) { - sink, err := datadog.NewDogStatsdSink(url, hostname) - if err != nil { - return nil, err - } - - // Don't override a setting from the user - var serviceTagSet bool - var ddTags []string - for k, v := range tags { - if strings.ToLower(k) == "service" { - serviceTagSet = true - } - - ddTags = append(ddTags, fmt.Sprintf("%s:%s", k, v)) - } - - for _, t := range extraTags { - sp := strings.Split(t, ":") - if len(sp) > 0 && strings.ToLower(sp[0]) == "service" { - serviceTagSet = true - } - - ddTags = append(ddTags, t) - } - - if !serviceTagSet { - ddTags = append(ddTags, fmt.Sprintf("service:%s", serviceName)) - } - - sink.SetTags(ddTags) - if err := initDistribution(url, serviceName, ddTags); err != nil { - return nil, errors.Wrap(err, "failed to initialize the datadog statsd client") - } - return sink, nil -} - -// -// Some simpler wrappers around go-metrics -// - -// Labels builds a dynamic list of labels -func Labels(labels ...metrics.Label) []metrics.Label { - return labels -} - -// L returns a single label, kept short for conciseness -func L(name string, value string) metrics.Label { - return metrics.Label{ - Name: name, - Value: value, - } -} - -// Inc increments a simple counter -// -// Example: -// -// metriks.Inc("publisher.errors", 1) -// -func Inc(name string, val int64, labels ...metrics.Label) { - if len(labels) > 0 { - metrics.IncrCounterWithLabels([]string{name}, float32(val), labels) - } else { - metrics.IncrCounter([]string{name}, float32(val)) - } -} - -// IncLabels increments a counter with additional labels -// -// Example: -// -// metriks.IncLabels("publisher.errors", metriks.Labels(metriks.L("status_class", "4xx")), 1) -// -func IncLabels(name string, labels []metrics.Label, val int64) { - Inc(name, val, labels...) -} - -// MeasureSince records the time from start until the invocation of the function -// It is usually used with `defer` to record time of a function. -// -// Example: -// -// func getRows() ([]Row) { -// defer metriks.MeasureSince("publisher-get-rows.time", time.Now()) -// -// query := "SELECT * FROM publisher" -// return db.Execute(query) -// } -// -func MeasureSince(name string, start time.Time, labels ...metrics.Label) { - if len(labels) > 0 { - metrics.MeasureSinceWithLabels([]string{name}, start, labels) - } else { - metrics.MeasureSince([]string{name}, start) - } -} - -// MeasureSinceLabels is the same as MeasureSince, but with additional labels -func MeasureSinceLabels(name string, labels []metrics.Label, start time.Time) { - MeasureSince(name, start, labels...) -} - -// Sample records a float32 sample as part of a histogram. This will get histogram -// distribution metrics -// -// Example: -// -// metriks.Sample("publisher-payload-size", float32(len(payload))) -// -func Sample(name string, val float32, labels ...metrics.Label) { - if len(labels) > 0 { - metrics.AddSampleWithLabels([]string{name}, val, labels) - } else { - metrics.AddSample([]string{name}, val) - } -} - -// SampleLabels is the same as Sample but with additional labels -func SampleLabels(name string, labels []metrics.Label, val float32) { - Sample(name, val, labels...) -} - -// Gauge is used to report a single float32 value. It is most often used during a -// periodic update or timer to report the current size of a queue or how many -// connections are currently connected. -func Gauge(name string, val float32, labels ...metrics.Label) { - if len(labels) > 0 { - metrics.SetGaugeWithLabels([]string{name}, val, labels) - } else { - metrics.SetGauge([]string{name}, val) - } -} - -// GaugeLabels is the same as Gauge but with additional labels -func GaugeLabels(name string, labels []metrics.Label, val float32) { - Gauge(name, val, labels...) -} - -func (conf Config) StatsdAddr() string { - return fmt.Sprintf("%s:%d", conf.Host, conf.Port) -} diff --git a/metriks/metriks_test.go b/metriks/metriks_test.go deleted file mode 100644 index 3e891c5..0000000 --- a/metriks/metriks_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package metriks - -import ( - "bytes" - "fmt" - "testing" - - "github.com/armon/go-metrics" - "github.com/armon/go-metrics/datadog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/nettest" -) - -func TestMetriksInit(t *testing.T) { - err := InitWithSink("foo", &metrics.BlackholeSink{}) - require.NoError(t, err) - - config := Config{ - Host: "127.0.0.1", - Port: 8125, - Tags: nil, - } - err = Init("foo", config) - require.NoError(t, err) -} - -func TestDatadogSink(t *testing.T) { - l, err := nettest.NewLocalPacketListener("udp") - require.NoError(t, err) - defer l.Close() - - endpoint := fmt.Sprintf("datadog://%s?namespace=edge_state&tag=app:edge-state&tag=env:test", - l.LocalAddr().String()) - sink, err := InitWithURL("test", endpoint) - require.NoError(t, err) - require.IsType(t, &datadog.DogStatsdSink{}, sink) - - Inc("test_counter", 1) - - expectedMsg := "test.test_counter:1|c|#app:edge-state,env:test,service:test\n" - - var readBytes int - buf := make([]byte, 512) - readBytes, _, err = l.ReadFrom(buf) - require.NoError(t, err) - require.True(t, readBytes > 0) - - require.True(t, bytes.Equal(buf[0:readBytes], []byte(expectedMsg))) -} - -func TestDiscardSink(t *testing.T) { - sink, err := InitWithURL("test", "discard://") - require.NoError(t, err) - require.IsType(t, &metrics.BlackholeSink{}, sink) - - sink, err = InitWithURL("test", "") - require.NoError(t, err) - require.IsType(t, &metrics.BlackholeSink{}, sink) -} - -func TestInMemorySink(t *testing.T) { - sink, err := InitWithURL("test", "inmem://?interval=1s&retain=2s") - require.NoError(t, err) - require.IsType(t, &metrics.InmemSink{}, sink) - - Inc("test_counter", 1) - - met := sink.(*metrics.InmemSink).Data() - require.Len(t, met, 1) - require.Len(t, met[0].Counters, 1) - require.Contains(t, met[0].Counters, "test.test_counter") - require.Equal(t, "test.test_counter", met[0].Counters["test.test_counter"].Name) - require.Equal(t, 1, met[0].Counters["test.test_counter"].Count) -} - -func TestIncWithLabels(t *testing.T) { - sink, err := InitWithURL("test", "inmem://?interval=1s&retain=2s") - require.NoError(t, err) - require.IsType(t, &metrics.InmemSink{}, sink) - - Inc("test_counter", 1, L("tag", "value"), L("tag2", "value2")) - met := sink.(*metrics.InmemSink).Data() - require.Len(t, met, 1) - - require.Len(t, met[0].Counters, 1) - var incr metrics.SampledValue - for _, v := range met[0].Counters { - incr = v - break - } - - assert.Len(t, incr.Labels, 2) - for _, l := range incr.Labels { - switch l.Name { - case "tag": - assert.Equal(t, "value", l.Value) - case "tag2": - assert.Equal(t, "value2", l.Value) - default: - assert.Fail(t, "unexpected label value") - } - } -} diff --git a/mongoclient/db.go b/mongoclient/db.go index 1701ff9..9d0cd64 100644 --- a/mongoclient/db.go +++ b/mongoclient/db.go @@ -10,8 +10,6 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" "github.com/sirupsen/logrus" - - "github.com/netlify/netlify-commons/nconf" ) const ( @@ -30,7 +28,7 @@ type Auth struct { type Config struct { AppName string - TLS *nconf.TLSConfig + TLS *TLSConfig Servers []string ReplSetName string ConnTimeout time.Duration diff --git a/mongoclient/options.go b/mongoclient/options.go index 53b0706..8d15db8 100644 --- a/mongoclient/options.go +++ b/mongoclient/options.go @@ -4,7 +4,6 @@ import ( "strings" "time" - "github.com/netlify/netlify-commons/nconf" "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -12,7 +11,7 @@ import ( type Option func(opt *options.ClientOptions) error -func TLSOption(log logrus.FieldLogger, config nconf.TLSConfig) Option { +func TLSOption(log logrus.FieldLogger, config TLSConfig) Option { return func(opts *options.ClientOptions) error { if !config.Enabled { log.Debug("Skipping TLS config") diff --git a/mongoclient/tls.go b/mongoclient/tls.go new file mode 100644 index 0000000..92d1926 --- /dev/null +++ b/mongoclient/tls.go @@ -0,0 +1,100 @@ +package mongoclient + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + + "github.com/pkg/errors" +) + +type TLSConfig struct { + CAFiles []string `mapstructure:"ca_files" envconfig:"ca_files" json:"ca_files" yaml:"ca_files"` + KeyFile string `mapstructure:"key_file" split_words:"true" json:"key_file" yaml:"key_file"` + CertFile string `mapstructure:"cert_file" split_words:"true" json:"cert_file" yaml:"cert_file"` + + Cert string `mapstructure:"cert"` + Key string `mapstructure:"key"` + CA string `mapstructure:"ca"` + + Insecure bool `default:"false"` + Enabled bool `default:"false"` +} + +func (cfg TLSConfig) TLSConfig() (*tls.Config, error) { + var err error + + tlsConf := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: cfg.Insecure, + } + + // Load CA + if cfg.CA != "" { + tlsConf.RootCAs, err = LoadCAFromValue(cfg.CA) + } else if len(cfg.CAFiles) > 0 { + tlsConf.RootCAs, err = LoadCAFromFiles(cfg.CAFiles) + } else { + tlsConf.RootCAs, err = x509.SystemCertPool() + } + + if err != nil { + return nil, errors.Wrap(err, "Error setting up Root CA pool") + } + + // Load Certs if any + var cert tls.Certificate + if cfg.Cert != "" && cfg.Key != "" { + cert, err = LoadCertFromValues(cfg.Cert, cfg.Key) + tlsConf.Certificates = append(tlsConf.Certificates, cert) + } else if cfg.CertFile != "" && cfg.KeyFile != "" { + cert, err = LoadCertFromFiles(cfg.CertFile, cfg.KeyFile) + tlsConf.Certificates = append(tlsConf.Certificates, cert) + } + + if err != nil { + return nil, errors.Wrap(err, "Error loading certificate KeyPair") + } + + // Backwards compatibility: if TLS is not explicitly enabled, return nil if no certificate was provided + // Old code disabled TLS by not providing a certificate, which returned nil when calling TLSConfig() + if !cfg.Enabled && len(tlsConf.Certificates) == 0 { + return nil, nil + } + + return tlsConf, nil +} + +func LoadCertFromValues(certPEM, keyPEM string) (tls.Certificate, error) { + return tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) +} + +func LoadCertFromFiles(certFile, keyFile string) (tls.Certificate, error) { + return tls.LoadX509KeyPair(certFile, keyFile) +} + +func LoadCAFromFiles(cafiles []string) (*x509.CertPool, error) { + pool := x509.NewCertPool() + + for _, caFile := range cafiles { + caData, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + + if !pool.AppendCertsFromPEM(caData) { + return nil, fmt.Errorf("Failed to add CA cert at %s", caFile) + } + } + + return pool, nil +} + +func LoadCAFromValue(ca string) (*x509.CertPool, error) { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM([]byte(ca)) { + return nil, fmt.Errorf("Failed to add CA cert") + } + return pool, nil +} diff --git a/nconf/args.go b/nconf/args.go deleted file mode 100644 index 3051571..0000000 --- a/nconf/args.go +++ /dev/null @@ -1,141 +0,0 @@ -package nconf - -import ( - "fmt" - "strings" - - "github.com/netlify/netlify-commons/featureflag" - "github.com/netlify/netlify-commons/instrument" - "github.com/netlify/netlify-commons/metriks" - "github.com/netlify/netlify-commons/tracing" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "github.com/spf13/pflag" -) - -type RootArgs struct { - Prefix string - ConfigFile string -} - -func (args *RootArgs) Setup(config interface{}, serviceName, version string) (logrus.FieldLogger, error) { - rootConfig, err := args.loadDefaultConfig() - if err != nil { - return nil, err - } - - log, err := ConfigureLogging(rootConfig.Log) - if err != nil { - return nil, errors.Wrap(err, "Failed to create the logger") - } - if version == "" { - version = "unknown" - } - log = log.WithField("version", version) - - if err := SetupBugSnag(rootConfig.BugSnag, version); err != nil { - return nil, errors.Wrap(err, "Failed to configure bugsnag") - } - - if err := metriks.Init(serviceName, rootConfig.Metrics); err != nil { - return nil, errors.Wrap(err, "Failed to configure metrics") - } - - // Handles the 'enabled' flag itself - tracing.Configure(&rootConfig.Tracing, log, serviceName) - - if err := featureflag.Init(rootConfig.FeatureFlag, log); err != nil { - return nil, errors.Wrap(err, "Failed to configure featureflags") - } - - if err := instrument.Init(rootConfig.Instrument, log); err != nil { - return nil, errors.Wrap(err, "Failed to configure instrument") - } - - if err := sendDatadogEvents(rootConfig.Metrics, serviceName, version); err != nil { - log.WithError(err).Error("Failed to send the startup events to datadog") - } - - if config != nil { - // second load the config for this project - if err := args.load(config); err != nil { - return log, errors.Wrap(err, "Failed to load the config object") - } - log.Debug("Loaded configuration") - } - return log, nil -} - -func (args *RootArgs) load(cfg interface{}) error { - loader := func(cfg interface{}) error { - return LoadFromEnv(args.Prefix, args.ConfigFile, cfg) - } - if !strings.HasSuffix(args.ConfigFile, ".env") { - loader = func(cfg interface{}) error { - return LoadConfigFromFile(args.ConfigFile, cfg) - } - } - return loader(cfg) -} - -func (args *RootArgs) MustSetup(config interface{}, serviceName, version string) logrus.FieldLogger { - logger, err := args.Setup(config, serviceName, version) - if err != nil { - if logger != nil { - logger.WithError(err).Fatal("Failed to setup configuration") - } else { - panic(fmt.Sprintf("Failed to setup configuration: %s", err.Error())) - } - } - - return logger -} - -func (args *RootArgs) loadDefaultConfig() (*RootConfig, error) { - c := DefaultConfig() - - if err := args.load(&c); err != nil { - return nil, errors.Wrap(err, "Failed to load the default configuration") - } - - return &c, nil -} - -func (args *RootArgs) AddFlags(cmd *cobra.Command) *cobra.Command { - cmd.Flags().AddFlag(args.ConfigFlag()) - cmd.Flags().AddFlag(args.PrefixFlag()) - return cmd -} - -func (args *RootArgs) ConfigFlag() *pflag.Flag { - return &pflag.Flag{ - Name: "config", - Shorthand: "c", - Usage: "A file to load configuration from, supported formats are env, json, and yaml", - Value: newStringValue("", &args.ConfigFile), - } -} - -func (args *RootArgs) PrefixFlag() *pflag.Flag { - return &pflag.Flag{ - Name: "prefix", - Shorthand: "p", - Usage: "A prefix to search for when looking for env vars", - Value: newStringValue("", &args.Prefix), - } -} - -type stringValue string - -func newStringValue(val string, p *string) *stringValue { - *p = val - return (*stringValue)(p) -} - -func (s *stringValue) Set(val string) error { - *s = stringValue(val) - return nil -} -func (s *stringValue) Type() string { return "string" } -func (s *stringValue) String() string { return string(*s) } diff --git a/nconf/args_test.go b/nconf/args_test.go deleted file mode 100644 index 03a0a14..0000000 --- a/nconf/args_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package nconf - -import ( - "encoding/json" - "io/ioutil" - "os" - "testing" - "time" - - "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -func TestArgsLoad(t *testing.T) { - cfg := &struct { - Something string - Other int - Overridden string - }{ - Something: "default", - Overridden: "this should change", - } - - tmp, err := ioutil.TempFile("", "something*.env") - require.NoError(t, err) - cfgStr := ` -PF_OTHER=10 -PF_OVERRIDDEN=not-that -PF_LOG_LEVEL=debug -PF_LOG_QUOTE_EMPTY_FIELDS=true -` - require.NoError(t, ioutil.WriteFile(tmp.Name(), []byte(cfgStr), 0644)) - - args := RootArgs{ - Prefix: "pf", - ConfigFile: tmp.Name(), - } - - log, err := args.Setup(cfg, "", "") - require.NoError(t, err) - - // check that we did call configure the logger - assert.NotNil(t, log) - entry := log.(*logrus.Entry) - assert.Equal(t, logrus.DebugLevel, entry.Logger.Level) - assert.True(t, entry.Logger.Formatter.(*logrus.TextFormatter).QuoteEmptyFields) - - assert.Equal(t, "default", cfg.Something) - assert.Equal(t, 10, cfg.Other) - assert.Equal(t, "not-that", cfg.Overridden) -} - -func TestArgsAddToCmd(t *testing.T) { - args := new(RootArgs) - var called int - cmd := cobra.Command{ - Run: func(_ *cobra.Command, _ []string) { - assert.Equal(t, "PF", args.Prefix) - assert.Equal(t, "file.env", args.ConfigFile) - called++ - }, - } - cmd.PersistentFlags().AddFlag(args.ConfigFlag()) - cmd.PersistentFlags().AddFlag(args.PrefixFlag()) - cmd.SetArgs([]string{"--config", "file.env", "--prefix", "PF"}) - require.NoError(t, cmd.Execute()) - assert.Equal(t, 1, called) -} - -func TestArgsLoadDefault(t *testing.T) { - configVals := map[string]interface{}{ - "log": map[string]interface{}{ - "level": "debug", - "fields": map[string]interface{}{ - "something": 1, - }, - }, - "bugsnag": map[string]interface{}{ - "api_key": "secrets", - "project_package": "package", - }, - "metrics": map[string]interface{}{ - "enabled": true, - "port": 8125, - "tags": map[string]string{ - "env": "prod", - }, - }, - "tracing": map[string]interface{}{ - "enabled": true, - "port": "9125", - "enable_debug": true, - }, - "featureflag": map[string]interface{}{ - "key": "magicalkey", - "request_timeout": "10s", - "enabled": true, - }, - "instrument": map[string]interface{}{ - "key": "greatkey", - "enabled": true, - }, - } - - scenes := []struct { - ext string - enc func(v interface{}) ([]byte, error) - }{ - {"json", json.Marshal}, - {"yaml", yaml.Marshal}, - } - for _, s := range scenes { - t.Run(s.ext, func(t *testing.T) { - f, err := ioutil.TempFile("", "test-config-*."+s.ext) - require.NoError(t, err) - defer os.Remove(f.Name()) - - b, err := s.enc(&configVals) - require.NoError(t, err) - _, err = f.Write(b) - require.NoError(t, err) - - args := RootArgs{ - ConfigFile: f.Name(), - } - cfg, err := args.loadDefaultConfig() - require.NoError(t, err) - - // logging - assert.Equal(t, "debug", cfg.Log.Level) - assert.Equal(t, true, cfg.Log.QuoteEmptyFields) - assert.Equal(t, "", cfg.Log.File) - assert.Equal(t, false, cfg.Log.DisableColors) - assert.Equal(t, "", cfg.Log.TSFormat) - - assert.Len(t, cfg.Log.Fields, 1) - assert.EqualValues(t, 1, cfg.Log.Fields["something"]) - assert.Equal(t, false, cfg.Log.UseNewLogger) - - // bugsnag - assert.Equal(t, "", cfg.BugSnag.Environment) - assert.Equal(t, "secrets", cfg.BugSnag.APIKey) - assert.Equal(t, "package", cfg.BugSnag.ProjectPackage) - - // metrics - assert.Equal(t, true, cfg.Metrics.Enabled) - assert.Equal(t, "localhost", cfg.Metrics.Host) - assert.Equal(t, 8125, cfg.Metrics.Port) - assert.Equal(t, map[string]string{"env": "prod"}, cfg.Metrics.Tags) - - // tracing - assert.Equal(t, true, cfg.Tracing.Enabled) - assert.Equal(t, "localhost", cfg.Tracing.Host) - assert.Equal(t, "9125", cfg.Tracing.Port) - assert.Empty(t, cfg.Tracing.Tags) - assert.Equal(t, true, cfg.Tracing.EnableDebug) - - // featureflag - assert.Equal(t, "magicalkey", cfg.FeatureFlag.Key) - assert.Equal(t, 10*time.Second, cfg.FeatureFlag.RequestTimeout.Duration) - assert.Equal(t, true, cfg.FeatureFlag.Enabled) - assert.Equal(t, false, cfg.FeatureFlag.DisableEvents) - assert.Equal(t, "", cfg.FeatureFlag.RelayHost) - - // instrument - assert.Equal(t, "greatkey", cfg.Instrument.Key) - assert.Equal(t, true, cfg.Instrument.Enabled) - }) - } -} diff --git a/nconf/bugsnag.go b/nconf/bugsnag.go deleted file mode 100644 index cf97289..0000000 --- a/nconf/bugsnag.go +++ /dev/null @@ -1,35 +0,0 @@ -package nconf - -import ( - "github.com/bugsnag/bugsnag-go/v2" -) - -type BugSnagConfig struct { - Environment string - APIKey string `envconfig:"api_key" json:"api_key" yaml:"api_key"` - ProjectPackage string `envconfig:"project_package" json:"project_package" yaml:"project_package"` - NodeName string `envconfig:"node_name" json:"node_name" yaml:"node_name"` // If left unset, bugsnag will default to the value returned by os.Hostname -} - -func SetupBugSnag(config *BugSnagConfig, version string) error { - if config == nil || config.APIKey == "" { - return nil - } - - projectPackages := make([]string, 0, 2) - projectPackages = append(projectPackages, "main") - if config.ProjectPackage != "" { - projectPackages = append(projectPackages, config.ProjectPackage) - } - - bugsnag.Configure(bugsnag.Configuration{ - APIKey: config.APIKey, - ReleaseStage: config.Environment, - Hostname: config.NodeName, - AppVersion: version, - ProjectPackages: projectPackages, - PanicHandler: func() {}, // this is to disable panic handling. The lib was forking and restarting the process (causing races) - }) - - return nil -} diff --git a/nconf/configuration.go b/nconf/configuration.go deleted file mode 100644 index 7226f50..0000000 --- a/nconf/configuration.go +++ /dev/null @@ -1,127 +0,0 @@ -package nconf - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" - - "github.com/joho/godotenv" - "github.com/kelseyhightower/envconfig" - "github.com/netlify/netlify-commons/featureflag" - "github.com/netlify/netlify-commons/instrument" - "github.com/netlify/netlify-commons/metriks" - "github.com/netlify/netlify-commons/tracing" - "github.com/pkg/errors" - "github.com/spf13/viper" - "gopkg.in/yaml.v3" -) - -// ErrUnknownConfigFormat indicates the extension of the config file is not supported as a config source -type ErrUnknownConfigFormat struct { - ext string -} - -func (e *ErrUnknownConfigFormat) Error() string { - return fmt.Sprintf("Unknown config format: %s", e.ext) -} - -type RootConfig struct { - Log LoggingConfig - BugSnag *BugSnagConfig - Metrics metriks.Config - Tracing tracing.Config - FeatureFlag featureflag.Config - Instrument instrument.Config -} - -func DefaultConfig() RootConfig { - return RootConfig{ - Log: LoggingConfig{ - QuoteEmptyFields: true, - }, - Tracing: tracing.Config{ - Host: "localhost", - Port: "8126", - }, - Metrics: metriks.Config{ - Host: "localhost", - Port: 8125, - }, - } -} - -/* - Deprecated: This method relies on parsing the json/yaml to a map, then running it through mapstructure. - This required that both tags exist (annoying!). And so there is now LoadConfigFromFile. -*/ -// LoadFromFile will load the configuration from the specified file based on the file type -// There is only support for .json and .yml now -func LoadFromFile(configFile string, input interface{}) error { - if configFile == "" { - return nil - } - - switch { - case strings.HasSuffix(configFile, ".json"): - viper.SetConfigType("json") - case strings.HasSuffix(configFile, ".yaml"): - fallthrough - case strings.HasSuffix(configFile, ".yml"): - viper.SetConfigType("yaml") - } - viper.SetConfigFile(configFile) - - if err := viper.ReadInConfig(); err != nil && !os.IsNotExist(err) { - _, ok := err.(viper.ConfigFileNotFoundError) - if !ok { - return errors.Wrap(err, "reading configuration from files") - } - } - - return viper.Unmarshal(input) -} - -// LoadConfigFromFile will load the configuration from the specified file based on the file type -// There is only support for .json and .yml now. It will use the underlying json/yaml packages directly. -// meaning those should be the only required tags. -func LoadConfigFromFile(configFile string, input interface{}) error { - if configFile == "" { - return nil - } - - // read in all the bytes - data, err := ioutil.ReadFile(configFile) - if err != nil { - return err - } - - configExt := filepath.Ext(configFile) - switch configExt { - case ".json": - return json.Unmarshal(data, input) - case ".yaml", ".yml": - return yaml.Unmarshal(data, input) - } - return &ErrUnknownConfigFormat{configExt} -} - -func LoadFromEnv(prefix, filename string, face interface{}) error { - var err error - if filename == "" { - err = godotenv.Load() - if os.IsNotExist(err) { - err = nil - } - } else { - err = godotenv.Load(filename) - } - - if err != nil { - return err - } - - return envconfig.Process(prefix, face) -} diff --git a/nconf/configuration_test.go b/nconf/configuration_test.go deleted file mode 100644 index 0b0df29..0000000 --- a/nconf/configuration_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package nconf - -import ( - "encoding/json" - "io/ioutil" - "os" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -type testConfig struct { - Hero string - Villian string - Matchups map[string]string - Cities []string - - ShootingLocation string `json:"shooting_location" yaml:"shooting_location" split_words:"true"` -} - -func exampleConfig() testConfig { - return testConfig{ - Hero: "batman", - Villian: "joker", - Matchups: map[string]string{ - "batman": "superman", - "superman": "luther", - }, - Cities: []string{"gotham", "central", "star"}, - } -} - -func TestEnvLoadingNoFile(t *testing.T) { - env := os.Environ() - os.Clearenv() - defer func() { - for _, pair := range env { - parts := strings.SplitN(pair, "=", 2) - os.Setenv(parts[0], parts[1]) - } - }() - - os.Setenv("TEST_VILLIAN", "joker") - os.Setenv("TEST_HERO", "batman") - os.Setenv("TEST_MATCHUPS", "batman:superman,superman:luther") - os.Setenv("TEST_CITIES", "gotham,central,star") - - var results testConfig - assert.NoError(t, LoadFromEnv("test", "", &results)) - validateConfig(t, exampleConfig(), results) -} - -func TestEnvLoadingMissingFile(t *testing.T) { - err := LoadFromEnv("test", "should-exist.env", &struct{}{}) - assert.Error(t, err) -} - -func TestEnvLoadingFromFile(t *testing.T) { - os.Clearenv() - data := ` -TEST_VILLIAN=joker -TEST_HERO=batman -TEST_MATCHUPS=batman:superman,superman:luther -TEST_CITIES=gotham,central,star -` - filename := writeTestFile(t, "env", []byte(data)) - defer os.Remove(filename) - - var results testConfig - assert.NoError(t, LoadFromEnv("test", filename, &results)) - validateConfig(t, exampleConfig(), results) -} - -func TestFileLoadingNoFile(t *testing.T) { - var results = testConfig{ - Hero: "flash", - } - var expected = testConfig{ - Hero: "flash", - } - require.NoError(t, LoadFromFile("", &results)) - validateConfig(t, expected, results) -} - -func TestFileLoadJSON(t *testing.T) { - expected := exampleConfig() - bytes, err := json.Marshal(&expected) - require.NoError(t, err) - filename := writeTestFile(t, "json", bytes) - defer os.Remove(filename) - - var results testConfig - require.NoError(t, LoadConfigFromFile(filename, &results)) - validateConfig(t, expected, results) -} - -func TestFileLoadYAML(t *testing.T) { - expected := exampleConfig() - bytes, err := yaml.Marshal(&expected) - require.NoError(t, err) - filename := writeTestFile(t, "yaml", bytes) - defer os.Remove(filename) - - var results testConfig - require.NoError(t, LoadConfigFromFile(filename, &results)) - validateConfig(t, expected, results) -} - -func TestFileLoadWithSetFields(t *testing.T) { - expected := testConfig{ - Hero: "wonder woman", - } - // serailize it without the villain set - bytes, err := json.Marshal(&expected) - require.NoError(t, err) - filename := writeTestFile(t, "json", bytes) - defer os.Remove(filename) - - // set a default field - expected.Villian = "circe" - expected.Cities = []string{"gotham"} - - var results testConfig - require.NoError(t, LoadFromFile(filename, &results)) - - // this will overwrite ALL the values - assert.Equal(t, "", results.Villian) - assert.Equal(t, "wonder woman", results.Hero) - assert.Len(t, results.Cities, 0) -} - -func TestEnvLoadingWithTags(t *testing.T) { - data := ` -NESTED_WITH_TAG=loaded -WITH_TAG=not-loaded -NESTED_WITHOUT_TAG=loaded -NESTED_JAMMEDTOGETHER=loaded -` - filename := writeTestFile(t, "env", []byte(data)) - defer os.Remove(filename) - - results := struct { - Nested struct { - WithTag string `envconfig:"with_tag"` - WithoutTag string `split_words:"true"` - JammedTogether string - } - }{} - - require.NoError(t, LoadFromEnv("", filename, &results)) - assert.Equal(t, "loaded", results.Nested.WithTag) - assert.Equal(t, "loaded", results.Nested.JammedTogether) - assert.Equal(t, "loaded", results.Nested.WithoutTag) -} - -func writeTestFile(t *testing.T, ext string, data []byte) string { - f, err := ioutil.TempFile("", "test-*."+ext) - require.NoError(t, err) - - ioutil.WriteFile(f.Name(), data, 0644) - require.NoError(t, f.Close()) - return f.Name() -} - -func validateConfig(t *testing.T, expected testConfig, results testConfig) { - assert.Equal(t, expected.Hero, results.Hero) - assert.Equal(t, expected.Villian, results.Villian) - assert.Len(t, results.Cities, len(expected.Cities)) - for _, city := range expected.Cities { - assert.Contains(t, results.Cities, city) - } - - assert.Len(t, results.Matchups, len(expected.Matchups)) - for k, v := range expected.Matchups { - assert.Equal(t, v, results.Matchups[k]) - } -} diff --git a/nconf/events.go b/nconf/events.go deleted file mode 100644 index 14544cc..0000000 --- a/nconf/events.go +++ /dev/null @@ -1,62 +0,0 @@ -package nconf - -import ( - "fmt" - "os" - "os/signal" - "syscall" - - ddstatsd "github.com/DataDog/datadog-go/statsd" - "github.com/netlify/netlify-commons/metriks" - "github.com/pkg/errors" -) - -func sendDatadogEvents(conf metriks.Config, serviceName, version string) error { - if !conf.Enabled { - return nil - } - - client, err := ddstatsd.New(conf.StatsdAddr()) - if err != nil { - return errors.Wrap(err, "failed to connect to datadog agent") - } - - tags := []string{ - fmt.Sprintf("version:%s", version), - fmt.Sprintf("service:%s", serviceName), - } - host, err := os.Hostname() - if err != nil { - return errors.Wrap(err, "failed to get the hostname") - } - - key := "hostname" - if val := os.Getenv("KUBERNETES_PORT"); val != "" { - key = "pod" - } - tags = append(tags, fmt.Sprintf("%s:%s", key, host)) - - start := &ddstatsd.Event{ - Tags: append(tags, "event_type:startup"), - Title: fmt.Sprintf("Service Start: %s", serviceName), - Text: fmt.Sprintf("Service %s @ %s is starting", serviceName, version), - } - if err := client.Event(start); err != nil { - return errors.Wrap(err, "failed to send startup event") - } - - done := &ddstatsd.Event{ - Tags: append(tags, "event_type:shutdown"), - Title: fmt.Sprintf("Service Shutdown: %s", serviceName), - Text: fmt.Sprintf("Service '%s @ %s' is stopping", serviceName, version), - } - signals := make(chan os.Signal, 1) - signal.Notify(signals, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) - - go func() { - <-signals - _ = client.Event(done) - }() - - return nil -} diff --git a/nconf/logging.go b/nconf/logging.go deleted file mode 100644 index b9507d7..0000000 --- a/nconf/logging.go +++ /dev/null @@ -1,61 +0,0 @@ -package nconf - -import ( - "os" - "time" - - "github.com/sirupsen/logrus" -) - -type LoggingConfig struct { - Level string `mapstructure:"log_level"` - File string `mapstructure:"log_file"` - DisableColors bool `mapstructure:"disable_colors" split_words:"true" json:"disable_colors" yaml:"disable_colors"` - QuoteEmptyFields bool `mapstructure:"quote_empty_fields" split_words:"true" json:"quote_empty_fields" yaml:"quote_empty_fields"` - TSFormat string `mapstructure:"ts_format" json:"ts_format" yaml:"ts_format"` - Fields map[string]interface{} `mapstructure:"fields"` - UseNewLogger bool `mapstructure:"use_new_logger" split_words:"true" json:"use_new_logger" yaml:"use_new_logger"` -} - -func ConfigureLogging(config LoggingConfig) (*logrus.Entry, error) { - logger := logrus.New() - - tsFormat := time.RFC3339Nano - if config.TSFormat != "" { - tsFormat = config.TSFormat - } - // always use the full timestamp - logger.SetFormatter(&logrus.TextFormatter{ - FullTimestamp: true, - DisableTimestamp: false, - TimestampFormat: tsFormat, - DisableColors: config.DisableColors, - QuoteEmptyFields: config.QuoteEmptyFields, - }) - - // use a file if you want - if config.File != "" { - f, errOpen := os.OpenFile(config.File, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0664) - if errOpen != nil { - return nil, errOpen - } - logger.SetOutput(f) - logger.Infof("Set output file to %s", config.File) - } - - if config.Level != "" { - level, err := logrus.ParseLevel(config.Level) - if err != nil { - return nil, err - } - logger.SetLevel(level) - logger.Debug("Set log level to: " + logger.GetLevel().String()) - } - - f := logrus.Fields{} - for k, v := range config.Fields { - f[k] = v - } - - return logger.WithFields(f), nil -} diff --git a/nconf/timeout.go b/nconf/timeout.go deleted file mode 100644 index a4d8437..0000000 --- a/nconf/timeout.go +++ /dev/null @@ -1,32 +0,0 @@ -package nconf - -import ( - "github.com/netlify/netlify-commons/util" -) - -// HTTPServerTimeoutConfig represents common HTTP server timeout values -type HTTPServerTimeoutConfig struct { - // Read = http.Server.ReadTimeout - Read util.Duration `mapstructure:"read"` - // Write = http.Server.WriteTimeout - Write util.Duration `mapstructure:"write"` - // Handler = http.TimeoutHandler (or equivalent). - // The maximum amount of time a server handler can take. - Handler util.Duration `mapstructure:"handler"` -} - -// HTTPClientTimeoutConfig represents common HTTP client timeout values -type HTTPClientTimeoutConfig struct { - // Dial = net.Dialer.Timeout - Dial util.Duration `mapstructure:"dial"` - // KeepAlive = net.Dialer.KeepAlive - KeepAlive util.Duration `mapstructure:"keep_alive" split_words:"true" json:"keep_alive" yaml:"keep_alive"` - - // TLSHandshake = http.Transport.TLSHandshakeTimeout - TLSHandshake util.Duration `mapstructure:"tls_handshake" split_words:"true" json:"tls_handshake" yaml:"tls_handshake"` - // ResponseHeader = http.Transport.ResponseHeaderTimeout - ResponseHeader util.Duration `mapstructure:"response_header" split_words:"true" json:"response_header" yaml:"response_header"` - // Total = http.Client.Timeout or equivalent - // The maximum amount of time a client request can take. - Total util.Duration `mapstructure:"total"` -} diff --git a/nconf/timeout_test.go b/nconf/timeout_test.go deleted file mode 100644 index f3583c0..0000000 --- a/nconf/timeout_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package nconf - -import ( - "encoding/json" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -func TestParseTimeoutValues(t *testing.T) { - raw := map[string]string{ - "dial": "10s", - "keep_alive": "11s", - "tls_handshake": "12s", - "response_header": "13s", - "total": "14s", - } - - // write it to json & yaml - // then load it through the RootArgs.load - scenes := []struct { - name string - enc func(interface{}) ([]byte, error) - dec func([]byte, interface{}) error - }{ - {"json", json.Marshal, json.Unmarshal}, - {"yaml", yaml.Marshal, yaml.Unmarshal}, - } - for _, s := range scenes { - t.Run(s.name, func(t *testing.T) { - bs, err := s.enc(&raw) - require.NoError(t, err) - - var cfg HTTPClientTimeoutConfig - - require.NoError(t, s.dec(bs, &cfg)) - - assert.Equal(t, "10s", cfg.Dial.String()) - assert.Equal(t, "11s", cfg.KeepAlive.String()) - assert.Equal(t, "12s", cfg.TLSHandshake.String()) - assert.Equal(t, "13s", cfg.ResponseHeader.String()) - assert.Equal(t, "14s", cfg.Total.String()) - assert.Equal(t, 10*time.Second, cfg.Dial.Duration) - assert.Equal(t, 11*time.Second, cfg.KeepAlive.Duration) - assert.Equal(t, 12*time.Second, cfg.TLSHandshake.Duration) - assert.Equal(t, 13*time.Second, cfg.ResponseHeader.Duration) - assert.Equal(t, 14*time.Second, cfg.Total.Duration) - }) - } -} diff --git a/router/errors.go b/router/errors.go deleted file mode 100644 index 5de1b63..0000000 --- a/router/errors.go +++ /dev/null @@ -1,162 +0,0 @@ -package router - -import ( - "fmt" - "net/http" - "reflect" - - "github.com/bugsnag/bugsnag-go/v2" - "github.com/netlify/netlify-commons/metriks" - "github.com/netlify/netlify-commons/tracing" - "github.com/sirupsen/logrus" -) - -// HTTPError is an error with a message and an HTTP status code. -type HTTPError struct { - Code int `json:"code"` - Message string `json:"msg"` - JSON interface{} `json:"json,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` - ErrorID string `json:"error_id,omitempty"` - Fields logrus.Fields `json:"-"` -} - -// BadRequestError creates a 400 HTTP error -func BadRequestError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusBadRequest, fmtString, args...) -} - -// InternalServerError creates a 500 HTTP error -func InternalServerError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusInternalServerError, fmtString, args...) -} - -// NotFoundError creates a 404 HTTP error -func NotFoundError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusNotFound, fmtString, args...) -} - -// UnauthorizedError creates a 401 HTTP error -func UnauthorizedError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) -} - -// UnavailableServiceError creates a 503 HTTP error -func UnavailableServiceError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusServiceUnavailable, fmtString, args...) -} - -// Error will describe the HTTP error in text -func (e *HTTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%d: %s", e.Code, e.Message) -} - -// Cause will return the root cause error -func (e *HTTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError - } - return e -} - -// WithJSONError will add json details to the error -func (e *HTTPError) WithJSONError(json interface{}) *HTTPError { - e.JSON = json - return e -} - -// WithInternalError will add internal error information to an error -func (e *HTTPError) WithInternalError(err error) *HTTPError { - e.InternalError = err - return e -} - -// WithInternalMessage will add and internal message to an error -func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// WithFields will add fields to an error message -func (e *HTTPError) WithFields(fields logrus.Fields) *HTTPError { - for key, value := range fields { - e.Fields[key] = value - } - return e -} - -// WithFields will add fields to an error message -func (e *HTTPError) WithField(key string, value interface{}) *HTTPError { - e.Fields[key] = value - return e -} - -func httpError(code int, fmtString string, args ...interface{}) *HTTPError { - return &HTTPError{ - Code: code, - Message: fmt.Sprintf(fmtString, args...), - Fields: make(logrus.Fields), - } -} - -// HandleError will handle any error. If it is of type *HTTPError then it will -// log anything of a 50x or an InternalError. It will write the right error response -// to the client. This way if you return a BadRequestError, it will simply write to the client. -// Any non-HTTPError will be treated as unhandled and result in a 50x -func HandleError(err error, w http.ResponseWriter, r *http.Request) { - if err == nil || reflect.ValueOf(err).IsNil() { - return - } - - log := tracing.GetLogger(r) - errorID := tracing.GetRequestID(r) - - var notifyBugsnag bool - - switch e := err.(type) { - case *HTTPError: - log = log.WithFields(e.Fields) - - e.ErrorID = errorID - if e.Code >= http.StatusInternalServerError { - notifyBugsnag = true - elog := log.WithError(e) - if e.InternalError != nil { - elog = elog.WithField("internal_err", e.InternalError.Error()) - } - - elog.Errorf("internal server error: %s", e.InternalMessage) - } else if e.InternalError != nil { - notifyBugsnag = true - log.WithError(e).Infof("unexpected error: %s", e.InternalMessage) - } - - if jsonErr := SendJSON(w, e.Code, e); jsonErr != nil { - log.WithError(jsonErr).Error("Failed to write the JSON error response") - } - default: - notifyBugsnag = true - metriks.Inc("unhandled_errors", 1) - log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) - // hide real error details from response to prevent info leaks - w.WriteHeader(http.StatusInternalServerError) - if _, writeErr := w.Write([]byte(`{"code":500,"msg":"Internal server error","error_id":"` + errorID + `"}`)); writeErr != nil { - log.WithError(writeErr).Error("Error writing generic error message") - } - } - - if notifyBugsnag { - bugsnag.Notify(err, r, r.Context(), bugsnag.MetaData{ - "meta": map[string]interface{}{ - "error_id": errorID, - "error_msg": err.Error(), - "status_code": http.StatusInternalServerError, - "unhandled": true, - }, - }) - } -} diff --git a/router/errors_test.go b/router/errors_test.go deleted file mode 100644 index e58a743..0000000 --- a/router/errors_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package router - -import ( - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/armon/go-metrics" - "github.com/bugsnag/bugsnag-go/v2" - "github.com/netlify/netlify-commons/metriks" - "github.com/netlify/netlify-commons/tracing" - "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestHandleError_ErrorIsNil(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - w, r, _ := tracing.NewTracer( - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - HandleError(nil, w, r) - - assert.Empty(t, loggerOutput.AllEntries()) - assert.Empty(t, w.Header()) -} - -func TestHandleError_ErrorIsNilPointerToTypeHTTPError(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - w, r, _ := tracing.NewTracer( - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - h := func(_ http.ResponseWriter, _ *http.Request) *HTTPError { - return nil - } - - HandleError(h(w, r), w, r) - - assert.Empty(t, loggerOutput.AllEntries()) - assert.Empty(t, w.Header()) -} - -func TestHandleError_ErrorIsNilInterface(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - w, r, _ := tracing.NewTracer( - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - h := func(_ http.ResponseWriter, _ *http.Request) error { - return nil - } - - HandleError(h(w, r), w, r) - - assert.Empty(t, loggerOutput.AllEntries()) - assert.Empty(t, w.Header()) -} - -func TestHandleError_StandardError(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - w, r, _ := tracing.NewTracer( - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - HandleError(errors.New("random error"), w, r) - - require.Len(t, loggerOutput.AllEntries(), 1) - assert.Equal(t, "Unhandled server error: random error", loggerOutput.AllEntries()[0].Message) - assert.Empty(t, w.Header()) -} - -func TestHandleError_HTTPError(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - recorder := httptest.NewRecorder() - w, r, _ := tracing.NewTracer( - recorder, - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - httpErr := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), - InternalError: errors.New("random error"), - InternalMessage: "Something unexpected happened", - } - - HandleError(httpErr, w, r) - - resp := recorder.Result() - b, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - expectedBody := fmt.Sprintf(`{"code":500,"msg":"Internal Server Error","error_id":"%s"}`, tracing.GetRequestID(r)) - assert.Equal(t, expectedBody, string(b)) - - require.Len(t, loggerOutput.AllEntries(), 1) - assert.Equal(t, "internal server error: "+httpErr.InternalMessage, loggerOutput.AllEntries()[0].Message) -} - -func TestHandleError_HttpErrorWithFields(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - recorder := httptest.NewRecorder() - w, r, _ := tracing.NewTracer( - recorder, - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - httpErr := httpError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - - httpErr.WithFields(map[string]interface{}{ - "a": "1", - "b": "2", - }) - - httpErr.WithField("c", "3") - httpErr.WithField("a", "0") - - HandleError(httpErr, w, r) - - require.Len(t, loggerOutput.AllEntries(), 1) - entry := loggerOutput.LastEntry() - assert.Equal(t, entry.Data["a"], "0") - assert.Equal(t, entry.Data["b"], "2") - assert.Equal(t, entry.Data["c"], "3") -} - -func TestHandleError_NoLogForNormalErrors(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - recorder := httptest.NewRecorder() - w, r, _ := tracing.NewTracer( - recorder, - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - httpErr := BadRequestError("not found yo.") - - HandleError(httpErr, w, r) - - resp := recorder.Result() - b, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - expectedBody := fmt.Sprintf(`{"code":400,"msg":"not found yo.","error_id":"%s"}`, tracing.GetRequestID(r)) - assert.Equal(t, expectedBody, string(b)) - - // we shouldn't log anything, this is a normal error - require.Len(t, loggerOutput.AllEntries(), 0) -} - -type OtherError struct { - error string -} - -func (e *OtherError) Error() string { - return e.error -} - -func TestHandleError_ErrorIsNilPointerToTypeOtherError(t *testing.T) { - logger, loggerOutput := test.NewNullLogger() - w, r, _ := tracing.NewTracer( - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - "test", - "test", - ) - - var oe *OtherError - - HandleError(oe, w, r) - - require.Len(t, loggerOutput.AllEntries(), 0) - assert.Empty(t, w.Header()) -} - -func TestHandleError_ErrorGoesToBugsnag(t *testing.T) { - var called int - - bugsnag.OnBeforeNotify(func(event *bugsnag.Event, config *bugsnag.Configuration) error { - called++ - require.NotNil(t, event) - assert.NotNil(t, event.Ctx) - - assert.NotNil(t, config) - return errors.New("this should stop us from sending to bugsnag") - }) - - HandleError( - errors.New("this is an error"), - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - ) - assert.Equal(t, 1, called) - - // we shouldn't be notified of regular errors - HandleError( - NotFoundError("not found"), - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - ) - assert.Equal(t, 1, called) - - // we should be notified of internal server errors - HandleError( - InternalServerError("this is an error"), - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - ) - assert.Equal(t, 2, called) -} -func TestHandleError_ErrorEmitsMetric(t *testing.T) { - sink := metrics.NewInmemSink(time.Minute, time.Minute) - require.NoError(t, metriks.InitWithSink(t.Name(), sink)) - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/", nil) - HandleError(errors.New("this is an error"), w, r) - - assert.Equal(t, http.StatusInternalServerError, w.Code) - - assert.Len(t, sink.Data(), 1) -} diff --git a/router/helpers.go b/router/helpers.go deleted file mode 100644 index 86cfb7e..0000000 --- a/router/helpers.go +++ /dev/null @@ -1,21 +0,0 @@ -package router - -import ( - "encoding/json" - "fmt" - "net/http" - - "github.com/pkg/errors" -) - -// SendJSON will write the response object as JSON -func SendJSON(w http.ResponseWriter, status int, obj interface{}) error { - w.Header().Set("Content-Type", "application/json") - b, err := json.Marshal(obj) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj)) - } - w.WriteHeader(status) - _, err = w.Write(b) - return err -} diff --git a/router/middleware.go b/router/middleware.go deleted file mode 100644 index 2cd12b9..0000000 --- a/router/middleware.go +++ /dev/null @@ -1,156 +0,0 @@ -package router - -import ( - "bufio" - "bytes" - "fmt" - "net/http" - "os" - "regexp" - "runtime/debug" - "strings" - - "github.com/bugsnag/bugsnag-go/v2" - "github.com/netlify/netlify-commons/tracing" - "github.com/opentracing/opentracing-go" - "github.com/sirupsen/logrus" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" -) - -var bearerRegexp = regexp.MustCompile(`^(?:B|b)earer (\S+$)`) - -const versionHeaderTempl = "X-NF-%s-Version" - -type Middleware func(http.Handler) http.Handler - -func MiddlewareFunc(f func(w http.ResponseWriter, r *http.Request, next http.Handler)) Middleware { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - f(w, r, next) - }) - } -} - -func VersionHeader(serviceName, version string) Middleware { - if version == "" { - version = "unknown" - } - return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { - w.Header().Set(fmt.Sprintf(versionHeaderTempl, strings.ToUpper(serviceName)), version) - next.ServeHTTP(w, r) - }) -} - -func CheckAuth(secret string) Middleware { - return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { - if secret != "" { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - HandleError(UnauthorizedError("This endpoint requires a Bearer token"), w, r) - return - } - - matches := bearerRegexp.FindStringSubmatch(authHeader) - if len(matches) != 2 { - HandleError(UnauthorizedError("This endpoint requires a Bearer token"), w, r) - return - } - - if secret != matches[1] { - HandleError(UnauthorizedError("This endpoint requires a Bearer token"), w, r) - return - } - } - - next.ServeHTTP(w, r) - }) -} - -// Recoverer is a middleware that recovers from panics, logs the panic (and a -// backtrace), and returns a HTTP 500 (Internal Server Error) status if -// possible. Recoverer prints a request ID if one is provided. -func Recoverer(errLog logrus.FieldLogger) Middleware { - return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { - reqID := tracing.GetRequestID(r) - - defer func() { - if rvr := recover(); rvr != nil { - if errLog == nil { - logger := logrus.New() - logger.Out = os.Stderr - errLog = logrus.NewEntry(logger) - } - panicLog := errLog.WithField("request_id", reqID) - - stack := debug.Stack() - scanner := bufio.NewScanner(bytes.NewReader(stack)) - - var lineID int - panicLog.WithField("trace_line", lineID).Errorf("Panic: %+v", rvr) - for scanner.Scan() { - lineID++ - panicLog.WithField("trace_line", lineID).Errorf(scanner.Text()) - } - - se := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), - } - HandleError(se, w, r) - - // in the event of a panic none of the normal shutdown code is called - if span := opentracing.SpanFromContext(r.Context()); span != nil { - span.SetTag("error_id", reqID) - span.SetTag(ext.ErrorType, "panic") - span.SetTag(ext.HTTPCode, http.StatusInternalServerError) - - if err, ok := rvr.(error); ok { - span.SetTag(ext.ErrorMsg, err.Error()) - } - - defer span.Finish() - } - - if tr := tracing.GetFromContext(r.Context()); tr != nil { - tr.Finish() - } - } - }() - - next.ServeHTTP(w, r) - }) -} - -func HealthCheck(route string, f APIHandler) Middleware { - return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { - if r.URL.Path == route { - if f == nil { - w.WriteHeader(http.StatusOK) - return - } - - if err := f(w, r); err != nil { - HandleError(err, w, r) - } - - return - } - next.ServeHTTP(w, r) - }) -} - -func BugSnag() Middleware { - return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { - ctx := bugsnag.StartSession(r.Context()) - defer bugsnag.AutoNotify(ctx) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -func TrackAllRequests(log logrus.FieldLogger, service string) Middleware { - return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { - // This is to maintain some legacy span work. It will cause the APM requests - // to show up as the method on the top level - tracing.TrackRequest(w, r, log, service, r.Method, next) - }) -} diff --git a/router/middleware_test.go b/router/middleware_test.go deleted file mode 100644 index 0e0f185..0000000 --- a/router/middleware_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package router - -import ( - "errors" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/netlify/netlify-commons/tracing" - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/mocktracer" - "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" -) - -func TestCheckAuth(t *testing.T) { - validKey := "testkey" - invalidKey := "nopekey" - emptyKey := "" - - makeRequest := func(req *http.Request) *httptest.ResponseRecorder { - r := New(logrus.WithField("test", "CheckAuth")) - r.Use(CheckAuth(validKey)) - r.Get("/", func(w http.ResponseWriter, r *http.Request) error { - return nil - }) - rec := httptest.NewRecorder() - r.ServeHTTP(rec, req) - return rec - } - - t.Run("valid key", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validKey)) - rsp := makeRequest(req) - assert.Equal(t, http.StatusOK, rsp.Code) - }) - t.Run("lower case bearer", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("bearer %s", validKey)) - rsp := makeRequest(req) - assert.Equal(t, http.StatusOK, rsp.Code) - }) - - t.Run("invalid key", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", invalidKey)) - rsp := makeRequest(req) - assert.Equal(t, http.StatusUnauthorized, rsp.Code) - }) - t.Run("no header", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - rsp := makeRequest(req) - assert.Equal(t, http.StatusUnauthorized, rsp.Code) - }) - t.Run("empty key", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", emptyKey)) - rsp := makeRequest(req) - assert.Equal(t, http.StatusUnauthorized, rsp.Code) - }) - t.Run("invalid Authorization value", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("what even is this %s", invalidKey)) - rsp := makeRequest(req) - assert.Equal(t, http.StatusUnauthorized, rsp.Code) - }) - -} - -func TestRecoveryLogging(t *testing.T) { - logger, hook := test.NewNullLogger() - - mw := Recoverer(logger) - - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "http://doesntmatter.com", nil) - req.Header.Set(tracing.HeaderRequestUUID, "123456") - - // this should be captured by the recorder - handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic(errors.New("because I should")) - })) - - handler.ServeHTTP(rec, req) - require.NotEmpty(t, hook.AllEntries) - var lineID int - for _, e := range hook.AllEntries() { - assert.Equal(t, "123456", e.Data["request_id"], "missing the request_id: %v", e.Data) - assert.Equal(t, lineID, e.Data["trace_line"], "trace_line isn't in order: %v", e.Data) - lineID++ - } -} - -func TestRecoveryTracing(t *testing.T) { - mw := Recoverer(logrus.New()) - handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic(errors.New("because I should")) - })) - - mtracer := mocktracer.New() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set(tracing.HeaderRequestUUID, "123456") - ctx := opentracing.ContextWithSpan(r.Context(), mtracer.StartSpan(t.Name())) - r = r.WithContext(ctx) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - finished := mtracer.FinishedSpans() - assert.Len(t, finished, 1) - - tags := finished[0].Tags() - assert.Len(t, tags, 4) - assert.Equal(t, "123456", tags["error_id"]) - assert.Equal(t, "because I should", tags[ext.ErrorMsg]) - assert.Equal(t, "panic", tags[ext.ErrorType]) - assert.Equal(t, 500, tags[ext.HTTPCode]) -} - -func TestRecoveryInternalTracer(t *testing.T) { - logger, hook := test.NewNullLogger() - w, r, _ := tracing.NewTracer( - httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "/", nil), - logger, - t.Name(), - "some_resource", - ) - mw := Recoverer(logrus.New()) - handler := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic(errors.New("because I should")) - })) - handler.ServeHTTP(w, r) - - var found bool - for _, e := range hook.Entries { - found = e.Message == "Completed Request" - if found { - assert.Equal(t, 500, e.Data["status_code"]) - break - } - } - assert.True(t, found) -} diff --git a/router/options.go b/router/options.go deleted file mode 100644 index 980c8bb..0000000 --- a/router/options.go +++ /dev/null @@ -1,37 +0,0 @@ -package router - -type Option func(r *chiWrapper) - -func OptEnableCORS(r *chiWrapper) { - r.enableCORS = true -} - -func OptHealthCheck(path string, checker APIHandler) Option { - return func(r *chiWrapper) { - r.healthEndpoint = path - r.healthHandler = checker - } -} - -func OptVersionHeader(svcName, version string) Option { - return func(r *chiWrapper) { - if version == "" { - version = "unknown" - } - r.version = version - r.svcName = svcName - } -} - -func OptEnableTracing(svcName string) Option { - return func(r *chiWrapper) { - r.svcName = svcName - r.enableTracing = true - } -} - -func OptRecoverer() Option { - return func(r *chiWrapper) { - r.enableRecover = true - } -} diff --git a/router/router.go b/router/router.go deleted file mode 100644 index 1b7f91c..0000000 --- a/router/router.go +++ /dev/null @@ -1,195 +0,0 @@ -package router - -import ( - "net/http" - "strings" - - "github.com/netlify/netlify-commons/tracing" - "github.com/rs/cors" - - "github.com/go-chi/chi" - "github.com/sebest/xff" - "github.com/sirupsen/logrus" -) - -type chiWrapper struct { - chi chi.Router - - version string - svcName string - tracingPrefix string - rootLogger logrus.FieldLogger - - healthEndpoint string - healthHandler APIHandler - - enableTracing bool - enableCORS bool - enableRecover bool -} - -// Router wraps the chi router to make it slightly more accessible -type Router interface { - // Use appends one middleware onto the Router stack. - Use(fn Middleware) - - // With adds an inline middleware for an endpoint handler. - With(fn Middleware) Router - - // Route mounts a sub-Router along a `pattern`` string. - Route(pattern string, fn func(r Router)) - - // Method adds a routes for a `pattern` that matches the `method` HTTP method. - Method(method, pattern string, h APIHandler) - - // HTTP-method routing along `pattern` - Delete(pattern string, h APIHandler) - Get(pattern string, h APIHandler) - Post(pattern string, h APIHandler) - Put(pattern string, h APIHandler) - - // Mount attaches another http.Handler along ./pattern/* - Mount(pattern string, h http.Handler) - - ServeHTTP(http.ResponseWriter, *http.Request) -} - -// New creates a router with sensible defaults (xff, request id, cors) -func New(log logrus.FieldLogger, options ...Option) Router { - r := &chiWrapper{ - chi: chi.NewRouter(), - version: "unknown", - rootLogger: log, - } - - xffmw, _ := xff.Default() - r.Use(xffmw.Handler) - for _, opt := range options { - opt(r) - } - - if r.enableRecover { - r.Use(Recoverer(log)) - } - r.Use(VersionHeader(r.svcName, r.version)) - if r.enableCORS { - corsMiddleware := cors.New(cors.Options{ - AllowedMethods: []string{"GET", "POST", "PATCH", "PUT", "DELETE"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, - ExposedHeaders: []string{"Link", "X-Total-Count"}, - AllowCredentials: true, - }) - r.Use(corsMiddleware.Handler) - } - - if r.healthEndpoint != "" { - r.Use(HealthCheck(r.healthEndpoint, r.healthHandler)) - } - - return r -} - -// Route allows creating a generic route -func (r *chiWrapper) Route(pattern string, fn func(Router)) { - r.chi.Route(pattern, func(c chi.Router) { - wrapper := new(chiWrapper) - *wrapper = *r - wrapper.chi = c - wrapper.tracingPrefix = sanitizePattern(pattern) - fn(wrapper) - }) -} - -// Method adds a routes for a `pattern` that matches the `method` HTTP method. -func (r *chiWrapper) Method(method, pattern string, h APIHandler) { - r.chi.Method(method, pattern, r.traceRequest(method, pattern, h)) -} - -// Get adds a GET route -func (r *chiWrapper) Get(pattern string, fn APIHandler) { - r.chi.Get(pattern, r.traceRequest(http.MethodGet, pattern, fn)) -} - -// Post adds a POST route -func (r *chiWrapper) Post(pattern string, fn APIHandler) { - r.chi.Post(pattern, r.traceRequest(http.MethodPost, pattern, fn)) -} - -// Put adds a PUT route -func (r *chiWrapper) Put(pattern string, fn APIHandler) { - r.chi.Put(pattern, r.traceRequest(http.MethodPut, pattern, fn)) -} - -// Delete adds a DELETE route -func (r *chiWrapper) Delete(pattern string, fn APIHandler) { - r.chi.Delete(pattern, r.traceRequest(http.MethodDelete, pattern, fn)) -} - -// WithBypass adds an inline chi middleware for an endpoint handler -func (r *chiWrapper) With(fn Middleware) Router { - r.chi = r.chi.With(fn) - return r -} - -// UseBypass appends one chi middleware onto the Router stack -func (r *chiWrapper) Use(fn Middleware) { - r.chi.Use(fn) -} - -// ServeHTTP will serve a request -func (r *chiWrapper) ServeHTTP(w http.ResponseWriter, req *http.Request) { - r.chi.ServeHTTP(w, req) -} - -// Mount attaches another http.Handler along ./pattern/* -func (r *chiWrapper) Mount(pattern string, h http.Handler) { - if r.enableTracing { - h = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - tracing.TrackRequest(w, req, r.rootLogger, r.svcName, pattern, h) - }) - } - r.chi.Mount(pattern, h) -} - -// ======================================= -// HTTP handler with custom error payload -// ======================================= - -type APIHandler func(w http.ResponseWriter, r *http.Request) error - -func HandlerFunc(fn APIHandler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if err := fn(w, r); err != nil { - HandleError(err, w, r) - } - } -} - -func (r *chiWrapper) traceRequest(method, pattern string, fn APIHandler) http.HandlerFunc { - f := HandlerFunc(fn) - if r.enableTracing { - pattern = sanitizePattern(pattern) - if r.tracingPrefix != "" { - pattern = r.tracingPrefix + "." + pattern - } - - resourceName := strings.ToUpper(method) - if pattern != "" { - resourceName += "::" + pattern - } - - return func(w http.ResponseWriter, req *http.Request) { - tracing.TrackRequest(w, req, r.rootLogger, r.svcName, resourceName, f) - } - } - return f -} - -func sanitizePattern(pattern string) string { - pattern = strings.TrimPrefix(pattern, "/") - pattern = strings.ReplaceAll(pattern, "{", "") - pattern = strings.ReplaceAll(pattern, "}", "") - pattern = strings.ReplaceAll(pattern, "/", ".") - pattern = strings.TrimSuffix(pattern, ".") - return pattern -} diff --git a/router/router_test.go b/router/router_test.go deleted file mode 100644 index 1e6c4c4..0000000 --- a/router/router_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package router - -import ( - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/netlify/netlify-commons/testutil" - "github.com/netlify/netlify-commons/tracing" - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/mocktracer" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" -) - -func TestCORS(t *testing.T) { - req, err := http.NewRequest("OPTIONS", "/", nil) - require.NoError(t, err) - req.Header.Set("Origin", "myexamplehost.com") - req.Header.Set("Access-Control-Request-Method", "GET") - req.Header.Set("Access-Control-Request-Headers", "Content-Type") - t.Run("enabled", func(t *testing.T) { - rsp := do(t, []Option{OptEnableCORS}, "", "/", nil, req) - assert.Equal(t, http.StatusOK, rsp.Code) - }) - t.Run("disabled", func(t *testing.T) { - rsp := do(t, nil, "", "/", nil, req) - assert.Equal(t, http.StatusMethodNotAllowed, rsp.Code) - }) -} - -func TestCallthrough(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - var callCount int - handler := func(w http.ResponseWriter, r *http.Request) error { - callCount++ - return BadRequestError("") - } - rsp := do(t, nil, "", "/", handler, req) - assert.Equal(t, http.StatusBadRequest, rsp.Code) - assert.Equal(t, 1, callCount) -} - -func TestHealthEndpoint(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, "/health", nil) - require.NoError(t, err) - - scenarios := map[string]struct { - opts []Option - code int - }{ - "disabled": {[]Option{OptHealthCheck("", nil)}, http.StatusNotFound}, - "default": {[]Option{OptHealthCheck("/health", nil)}, http.StatusOK}, - "custom": {[]Option{OptHealthCheck( - "/health", - func(_ http.ResponseWriter, r *http.Request) error { - return UnauthorizedError("") - })}, - http.StatusUnauthorized}, - } - - for name, scene := range scenarios { - t.Run(name, func(t *testing.T) { - rsp := do(t, scene.opts, "", "/", nil, req) - assert.Equal(t, scene.code, rsp.Code) - }) - } -} - -func TestTracing(t *testing.T) { - og := opentracing.GlobalTracer() - mt := mocktracer.New() - opentracing.SetGlobalTracer(mt) - defer func() { - opentracing.SetGlobalTracer(og) - }() - - noop := func(w http.ResponseWriter, r *http.Request) error { - assert.NotNil(t, tracing.GetTracer(r)) - w.WriteHeader(http.StatusOK) - return nil - } - - tl, logHook := testutil.TestLogger(t) - r := New(tl, OptEnableTracing("some-service")) - - r.Method(http.MethodPatch, "/patch", noop) - r.Delete("/abc/{def}", noop) - r.Get("/abc/{def}", noop) - r.Get("/", noop) - r.Post("/def/ghi", noop) - r.Put("/asdf/", noop) - r.Route("/sub", func(r Router) { - r.Get("/path", noop) - }) - - scenes := map[string]struct { - method, path, resourceName string - }{ - "get": {http.MethodGet, "/abc/def", "GET::abc.def"}, - "delete": {http.MethodDelete, "/abc/hfj", "DELETE::abc.def"}, - "post": {http.MethodPost, "/def/ghi", "POST::def.ghi"}, - "put": {http.MethodPut, "/asdf/", "PUT::asdf"}, - "patch": {http.MethodPatch, "/patch", "PATCH::patch"}, - "subroute": {http.MethodGet, "/sub/path", "GET::sub.path"}, - "single_slash": {http.MethodGet, "/", "GET"}, - } - - for name, scene := range scenes { - t.Run(name, func(t *testing.T) { - mt.Reset() - logHook.Reset() - - rec := httptest.NewRecorder() - r.ServeHTTP(rec, httptest.NewRequest(scene.method, scene.path, nil)) - assert.Equal(t, http.StatusOK, rec.Code) - - spans := mt.FinishedSpans() - if assert.Equal(t, 1, len(spans)) { - assert.Equal(t, "some-service", spans[0].Tag(ext.ServiceName)) - assert.Equal(t, scene.resourceName, spans[0].Tag(ext.ResourceName)) - assert.Equal(t, strconv.Itoa(http.StatusOK), spans[0].Tag(ext.HTTPCode)) - } - // should be a starting and finished request for each request - assert.Len(t, logHook.AllEntries(), 2) - }) - } -} - -func TestVersionHeader(t *testing.T) { - scenes := map[string]struct { - version string - expected string - header string - svc string - }{ - "custom": {version: "123", expected: "123", header: "x-nf-something-version", svc: "something"}, - "default": {version: "", expected: "unknown", header: "x-nf-something-version", svc: "something"}, - } - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - for name, scene := range scenes { - t.Run(name, func(t *testing.T) { - opts := []Option{OptVersionHeader(scene.svc, scene.version)} - rsp := do(t, opts, scene.svc, "/", nil, req) - assert.Equal(t, scene.expected, rsp.Header().Get(scene.header), t.Name()) - }) - } -} - -func do(t *testing.T, opts []Option, svcName, path string, handler APIHandler, req *http.Request) *httptest.ResponseRecorder { - if opts == nil { - opts = []Option{} - } - r := New(logrus.WithField("test", t.Name()), opts...) - - if handler == nil { - handler = func(w http.ResponseWriter, r *http.Request) error { - return nil - } - } - if path != "" { - r.Get(path, handler) - } - rec := httptest.NewRecorder() - r.ServeHTTP(rec, req) - return rec -} diff --git a/server/options.go b/server/options.go deleted file mode 100644 index 32bb155..0000000 --- a/server/options.go +++ /dev/null @@ -1,44 +0,0 @@ -package server - -import ( - "crypto/tls" - "fmt" - "net/http" - "time" -) - -// Opt will allow modification of the http server -type Opt func(s *http.Server) - -// WithWriteTimeout will override the server's write timeout -func WithWriteTimeout(dur time.Duration) Opt { - return func(s *http.Server) { - s.WriteTimeout = dur - } -} - -// WithReadTimeout will override the server's read timeout -func WithReadTimeout(dur time.Duration) Opt { - return func(s *http.Server) { - s.ReadTimeout = dur - } -} - -// WithTLS will use the provided TLS configuration -func WithTLS(cfg *tls.Config) Opt { - return func(s *http.Server) { - s.TLSConfig = cfg - } -} - -// WithAddress will set the address field on the server -func WithAddress(addr string) Opt { - return func(s *http.Server) { - s.Addr = addr - } -} - -// WithHostAndPort will use them in the form host:port as the address field on the server -func WithHostAndPort(host string, port int) Opt { - return WithAddress(fmt.Sprintf("%s:%d", host, port)) -} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index 225d62a..0000000 --- a/server/server.go +++ /dev/null @@ -1,226 +0,0 @@ -package server - -import ( - "context" - "net/http" - "net/http/httptest" - "os" - "os/signal" - "sync" - "syscall" - "time" - - "github.com/netlify/netlify-commons/nconf" - "github.com/netlify/netlify-commons/router" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -const ( - defaultPort = 9090 - defaultHealthPath = "/health" -) - -// Server handles the setup and shutdown of the http server -// for an API -type Server struct { - log logrus.FieldLogger - svr *http.Server - api APIDefinition - done chan (bool) - doneOnce sync.Once -} - -type RouterConfig struct { - DisableTracing bool -} - -type Config struct { - HealthPath string `split_words:"true"` - Port int - Host string - TLS nconf.TLSConfig - Router RouterConfig -} - -// APIDefinition is used to control lifecycle of the API -type APIDefinition interface { - Start(r router.Router) error - Stop() - Info() APIInfo -} - -// APIInfo outlines the basic service information needed -type APIInfo struct { - Name string `json:"name"` - Version string `json:"version"` - Env string `json:"env"` -} - -// HealthChecker is used to run a custom health check -// Implement it on your API if you want it to be checked -// when the healthcheck is called -type HealthChecker interface { - Healthy(w http.ResponseWriter, r *http.Request) error -} - -// NewOpts will create the server with many defaults. You can use the opts to override them. -// the one major default you can't change by this is the health path. This is set to /health -// and be enabled. -func NewOpts(log logrus.FieldLogger, api APIDefinition, opts ...Opt) (*Server, error) { - defaultOpts := []Opt{ - WithHostAndPort("", defaultPort), - } - - return buildServer(log, api, append(defaultOpts, opts...), Config{ - HealthPath: defaultHealthPath, - }) -} - -// New will build a server with the defaults in place -func New(log logrus.FieldLogger, config Config, api APIDefinition) (*Server, error) { - opts := []Opt{ - WithHostAndPort(config.Host, config.Port), - } - - if config.TLS.Enabled { - tcfg, err := config.TLS.TLSConfig() - if err != nil { - return nil, errors.Wrap(err, "Failed to build TLS config") - } - log.Info("TLS enabled") - opts = append(opts, WithTLS(tcfg)) - } - - return buildServer(log, api, opts, config) -} - -func (s *Server) Shutdown(to time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), to) - defer cancel() - defer func() { - s.doneOnce.Do(func() { - close(s.done) - }) - }() - - if err := s.svr.Shutdown(ctx); err != nil && err != http.ErrServerClosed { - return err - } - return nil -} - -func (s *Server) ListenAndServe() error { - go s.waitForShutdown() - - s.log.Infof("Starting server at %s", s.svr.Addr) - var err error - if s.svr.TLSConfig != nil { - // this is already setup in the New, empties are ok here - err = s.svr.ListenAndServeTLS("", "") - } else { - err = s.svr.ListenAndServe() - } - // Now that server is no longer listening - s.log.Info("Listener shutdown, waiting for connections to drain") - - // Wait until Shutdown returns - <-s.done - - s.log.Info("Connections are drained, shutting down API") - - s.api.Stop() - - s.log.Debug("Completed shutting down the underlying API") - - if err == http.ErrServerClosed { - return nil - } - return err -} - -func (s *Server) TestServer() *httptest.Server { - return httptest.NewServer(s.svr.Handler) -} - -func (s *Server) waitForShutdown() { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - - s.log.Debug("Waiting for the shutdown signal") - sig := <-sigs - s.log.Infof("Received signal '%s', shutting down", sig) - if err := s.Shutdown(30 * time.Second); err != nil { - s.log.WithError(err).Warn("Failed to shutdown the server in time") - } -} - -type apiFunc struct { - start func(router.Router) error - stop func() - info APIInfo -} - -func (a apiFunc) Start(r router.Router) error { - return a.start(r) -} - -func (a apiFunc) Stop() { - a.stop() -} -func (a apiFunc) Info() APIInfo { - return a.info -} - -func APIFunc(start func(router.Router) error, stop func(), info APIInfo) APIDefinition { - return apiFunc{ - start: start, - stop: stop, - info: info, - } -} - -func buildRouter(log logrus.FieldLogger, api APIDefinition, config Config) router.Router { - var healthHandler router.APIHandler - if checker, ok := api.(HealthChecker); ok { - healthHandler = checker.Healthy - } - - opts := []router.Option{ - router.OptHealthCheck(config.HealthPath, healthHandler), - router.OptVersionHeader(api.Info().Name, api.Info().Version), - router.OptRecoverer(), - } - - if !config.Router.DisableTracing { - opts = append(opts, router.OptEnableTracing(api.Info().Name)) - } - - r := router.New( - log, - opts..., - ) - - return r -} - -func buildServer(log logrus.FieldLogger, api APIDefinition, opts []Opt, config Config) (*Server, error) { - r := buildRouter(log, api, config) - - if err := api.Start(r); err != nil { - return nil, errors.Wrap(err, "Failed to start API") - } - - svr := new(http.Server) - for _, o := range opts { - o(svr) - } - svr.Handler = r - s := Server{ - log: log.WithField("component", "server"), - svr: svr, - api: api, - done: make(chan bool), - } - return &s, nil -} diff --git a/server/server_test.go b/server/server_test.go deleted file mode 100644 index a258617..0000000 --- a/server/server_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package server - -import ( - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" - - "github.com/netlify/netlify-commons/router" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func init() { - if ll := os.Getenv("LOG_LEVEL"); strings.ToLower(ll) == "debug" { - logrus.SetLevel(logrus.DebugLevel) - } -} - -func TestServerHealth(t *testing.T) { - apiDef := APIFunc( - func(r router.Router) error { - r.Get("/", func(w http.ResponseWriter, r *http.Request) error { - return nil - }) - return nil - }, - func() { - }, - APIInfo{ - Name: t.Name(), - Version: "", - }, - ) - - cfg := testConfig() - svr, err := New(tl(t), cfg, apiDef) - require.NoError(t, err) - - testSvr := httptest.NewServer(svr.svr.Handler) - defer testSvr.Close() - - rsp, err := http.Get(testSvr.URL + cfg.HealthPath) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rsp.StatusCode) -} - -func TestServerVersioning(t *testing.T) { - apiDef := APIFunc( - func(r router.Router) error { - r.Get("/", func(w http.ResponseWriter, r *http.Request) error { - return nil - }) - return nil - }, - func() { - }, - APIInfo{ - Name: "testing", - Version: "", - }, - ) - cfg := testConfig() - t.Run("with-no-version", func(t *testing.T) { - svr, err := New(tl(t), cfg, apiDef) - require.NoError(t, err) - testSvr := httptest.NewServer(svr.svr.Handler) - defer testSvr.Close() - rsp, err := http.Get(testSvr.URL + cfg.HealthPath) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rsp.StatusCode) - assert.Equal(t, "unknown", rsp.Header.Get("X-Nf-Testing-Version")) - }) - - apiDef = APIFunc( - func(r router.Router) error { - r.Get("/", func(w http.ResponseWriter, r *http.Request) error { - return nil - }) - return nil - }, - func() { - }, - APIInfo{ - Name: "testing", - Version: "123", - }, - ) - - t.Run("with-version", func(t *testing.T) { - svr, err := New(tl(t), cfg, apiDef) - require.NoError(t, err) - testSvr := httptest.NewServer(svr.svr.Handler) - defer testSvr.Close() - rsp, err := http.Get(testSvr.URL + cfg.HealthPath) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rsp.StatusCode) - assert.Equal(t, "123", rsp.Header.Get("X-Nf-testing-version")) - }) - -} - -type testAPICustomHealth struct{} - -func (a *testAPICustomHealth) Start(r router.Router) error { - r.Get("/", func(w http.ResponseWriter, r *http.Request) error { - return nil - }) - return nil -} - -func (a *testAPICustomHealth) Stop() {} - -func (a *testAPICustomHealth) Healthy(w http.ResponseWriter, r *http.Request) error { - return router.InternalServerError("healthcheck failed") -} -func (a *testAPICustomHealth) Info() APIInfo { - return APIInfo{"testing", "", "testing"} -} - -func TestServerCustomHealth(t *testing.T) { - apiDef := new(testAPICustomHealth) - - cfg := testConfig() - svr, err := New(tl(t), cfg, apiDef) - require.NoError(t, err) - - testSvr := httptest.NewServer(svr.svr.Handler) - defer testSvr.Close() - - rsp, err := http.Get(testSvr.URL + cfg.HealthPath) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, rsp.StatusCode) -} - -func tl(t *testing.T) *logrus.Entry { - return logrus.WithField("test", t.Name()) -} - -func testConfig() Config { - return Config{ - HealthPath: "/health", - Port: 9090, - } -} - -func TestServerAddr(t *testing.T) { - apiDef := new(testAPICustomHealth) - cfg := testConfig() - svr, err := New(tl(t), cfg, apiDef) - require.NoError(t, err) - require.Equal(t, svr.svr.Addr, ":9090") - cfg.Host = "127.0.0.1" - svrWithHost, err := New(tl(t), cfg, apiDef) - require.NoError(t, err) - require.Equal(t, svrWithHost.svr.Addr, "127.0.0.1:9090") -} diff --git a/testutil/logger.go b/testutil/logger.go deleted file mode 100644 index 7a353f0..0000000 --- a/testutil/logger.go +++ /dev/null @@ -1,65 +0,0 @@ -package testutil - -import ( - "os" - "testing" - - "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/test" -) - -// Opt is a function that will modify the logger used -type Opt func(l *logrus.Logger) - -// TL will build and return a test logger -func TL(t *testing.T, opts ...Opt) *logrus.Entry { - l, _ := TestLogger(t, opts...) - return l -} - -// TestLogger will build a logger that is useful for debugging -// it respects levels configured by the 'LOG_LEVEL' env var. -// It takes opt functions to modify the logger used -func TestLogger(t *testing.T, opts ...Opt) (*logrus.Entry, *test.Hook) { - l := logrus.New() - l.SetOutput(testLogWrapper{t}) - hook := test.NewLocal(l) - l.SetReportCaller(true) - if ll := os.Getenv("LOG_LEVEL"); ll != "" { - level, err := logrus.ParseLevel(ll) - if err != nil { - t.Logf("Error parsing the log level env var (%s), defaulting to info", ll) - level = logrus.InfoLevel - } - l.SetLevel(level) - } - - for _, o := range opts { - o(l) - } - - return l.WithField("test", t.Name()), hook -} - -type testLogWrapper struct { - t *testing.T -} - -func (w testLogWrapper) Write(p []byte) (n int, err error) { - w.t.Log(string(p)) - return len(p), nil -} - -// OptSetLevel will override the env var used to configre the logger -func OptSetLevel(lvl logrus.Level) Opt { - return func(l *logrus.Logger) { - l.SetLevel(lvl) - } -} - -// OptReportCaller will override the reporting of the calling function info -func OptReportCaller(b bool) Opt { - return func(l *logrus.Logger) { - l.SetReportCaller(b) - } -} diff --git a/testutil/logger_test.go b/testutil/logger_test.go deleted file mode 100644 index d9bbe71..0000000 --- a/testutil/logger_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package testutil - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestLoggerWrapper(t *testing.T) { - tl, h := TestLogger(t) - tl.Info("this is a test") - assert.Len(t, h.Entries, 1) - assert.Equal(t, "this is a test", h.Entries[0].Message) - assert.Equal(t, "TestLoggerWrapper", h.Entries[0].Data["test"]) -} - -func TestLogWrapperLevel(t *testing.T) { - testCases := []struct { - desc string - expected []string - }{ - {desc: "DEBUG", expected: []string{"debug", "info", "warn", "error"}}, - {desc: "NONSENSE", expected: []string{"info", "warn", "error"}}, - {desc: "INFO", expected: []string{"info", "warn", "error"}}, - {desc: "ERROR", expected: []string{"error"}}, - } - for _, tC := range testCases { - t.Run(tC.desc, func(t *testing.T) { - os.Setenv("LOG_LEVEL", tC.desc) - tl, h := TestLogger(t) - tl.Debug("debug") - tl.Info("info") - tl.Warn("warn") - tl.Error("error") - - logged := []string{} - for _, entry := range h.Entries { - logged = append(logged, entry.Message) - } - assert.Equal(t, tC.expected, logged) - }) - } -} diff --git a/tracing/config.go b/tracing/config.go deleted file mode 100644 index 7ae54d1..0000000 --- a/tracing/config.go +++ /dev/null @@ -1,61 +0,0 @@ -package tracing - -import ( - "fmt" - "strings" - - "github.com/opentracing/opentracing-go" - "github.com/sirupsen/logrus" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" -) - -const ( - HeaderNFDebugLogging = "X-NF-Debug-Logging" - HeaderRequestUUID = "X-BB-CLIENT-REQUEST-UUID" -) - -type Config struct { - Enabled bool `default:"false"` - Host string `default:"localhost"` - Port string `default:"8126"` - Tags map[string]string - EnableDebug bool `default:"false" split_words:"true" mapstructure:"enable_debug" json:"enable_debug" yaml:"enable_debug"` -} - -func Configure(tc *Config, log logrus.FieldLogger, svcName string) { - var t opentracing.Tracer = opentracing.NoopTracer{} - if tc.Enabled { - tracerAddr := fmt.Sprintf("%s:%s", tc.Host, tc.Port) - tracerOps := []tracer.StartOption{ - tracer.WithService(svcName), - tracer.WithAgentAddr(tracerAddr), - tracer.WithDebugMode(tc.EnableDebug), - tracer.WithLogger(debugLogger{log.WithField("component", "opentracing")}), - } - - var serviceTagSet bool - for k, v := range tc.Tags { - if strings.ToLower(k) == "service" { - serviceTagSet = true - } - - tracerOps = append(tracerOps, tracer.WithGlobalTag(k, v)) - } - - if !serviceTagSet { - tracerOps = append(tracerOps, tracer.WithGlobalTag("service", svcName)) - } - - t = opentracer.New(tracerOps...) - } - opentracing.SetGlobalTracer(t) -} - -type debugLogger struct { - log logrus.FieldLogger -} - -func (l debugLogger) Log(msg string) { - l.log.Debug(msg) -} diff --git a/tracing/context.go b/tracing/context.go deleted file mode 100644 index 5f8d3d5..0000000 --- a/tracing/context.go +++ /dev/null @@ -1,50 +0,0 @@ -package tracing - -import ( - "context" - "net/http" - - uuid "github.com/satori/go.uuid" -) - -type contextKey string - -const tracerKey = contextKey("nf-tracer-key") - -func WrapWithTracer(r *http.Request, rt *RequestTracer) *http.Request { - return r.WithContext(context.WithValue(r.Context(), tracerKey, rt)) -} - -func GetFromContext(ctx context.Context) *RequestTracer { - val := ctx.Value(tracerKey) - if val == nil { - return nil - } - entry, ok := val.(*RequestTracer) - if ok { - return entry - } - return nil -} - -func GetTracer(r *http.Request) *RequestTracer { - return GetFromContext(r.Context()) -} - -func GetRequestID(r *http.Request) string { - if id := GetRequestIDFromContext(r.Context()); id != "" { - return id - } - if rid := r.Header.Get(HeaderRequestUUID); rid != "" { - return rid - } - return uuid.NewV4().String() -} - -func GetRequestIDFromContext(ctx context.Context) string { - tr := GetFromContext(ctx) - if tr == nil { - return "" - } - return tr.RequestID -} diff --git a/tracing/logging.go b/tracing/logging.go deleted file mode 100644 index 7affddd..0000000 --- a/tracing/logging.go +++ /dev/null @@ -1,80 +0,0 @@ -package tracing - -import ( - "context" - "net/http" - - uuid "github.com/satori/go.uuid" - "github.com/sirupsen/logrus" -) - -func RequestLogger(r *http.Request, log logrus.FieldLogger) (logrus.FieldLogger, string) { - if r.Header.Get(HeaderNFDebugLogging) != "" { - logger := logrus.New() - logger.SetLevel(logrus.DebugLevel) - - if entry, ok := log.(*logrus.Entry); ok { - log = logger.WithFields(entry.Data) - logger.Hooks = entry.Logger.Hooks - } - } - - reqID := r.Header.Get(HeaderRequestUUID) - if reqID == "" { - reqID = uuid.NewV4().String() - r.Header.Set(HeaderRequestUUID, reqID) - } - - log = log.WithFields(logrus.Fields{ - "request_id": reqID, - }) - - return log, reqID -} - -func GetLoggerFromContext(ctx context.Context) logrus.FieldLogger { - entry := GetFromContext(ctx) - if entry == nil { - return logrus.NewEntry(logrus.StandardLogger()) - } - return entry.FieldLogger -} - -func GetLogger(r *http.Request) logrus.FieldLogger { - return GetLoggerFromContext(r.Context()) -} - -// SetLogField will add the field to this log line and every one following -func SetLogField(r *http.Request, key string, value interface{}) logrus.FieldLogger { - entry := GetTracer(r) - if entry == nil { - return logrus.StandardLogger().WithField(key, value) - } - return entry.SetLogField(key, value) -} - -// SetLogFields will add the fields to this log line and every one following -func SetLogFields(r *http.Request, fields logrus.Fields) logrus.FieldLogger { - entry := GetTracer(r) - if entry == nil { - return logrus.StandardLogger().WithFields(fields) - } - - return entry.SetLogFields(fields) -} - -// SetFinalField will add a field to the canonical line created at in Finish. It will add -// it to this line, but not every log line in between -func SetFinalField(r *http.Request, key string, value interface{}) logrus.FieldLogger { - return SetFinalFieldWithContext(r.Context(), key, value) -} - -// SetFinalFieldWithContext will add a field to the canonical line created at in Finish. It will add -// it to this line, but not every log line in between -func SetFinalFieldWithContext(ctx context.Context, key string, value interface{}) logrus.FieldLogger { - entry := GetFromContext(ctx) - if entry == nil { - return logrus.StandardLogger().WithField(key, value) - } - return entry.SetFinalField(key, value) -} diff --git a/tracing/logging_test.go b/tracing/logging_test.go deleted file mode 100644 index 81a3361..0000000 --- a/tracing/logging_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package tracing - -import ( - "context" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestGetLoggerFromContext_ContextContainsValue(t *testing.T) { - tracer := new(RequestTracer) - tracer.FieldLogger = logrus.New() - ctx := context.WithValue(context.Background(), tracerKey, tracer) - - assert.Same(t, tracer.FieldLogger, GetLoggerFromContext(ctx)) -} - -func TestGetLoggerFromContext_ContextDoesNotContainValue(t *testing.T) { - l := GetLoggerFromContext(context.Background()) - assert.NotNil(t, l) - assert.Implements(t, (*logrus.FieldLogger)(nil), l) -} - -func TestGetLogger_ContextContainsValue(t *testing.T) { - tracer := new(RequestTracer) - tracer.FieldLogger = logrus.New() - ctx := context.WithValue(context.Background(), tracerKey, tracer) - - r := httptest.NewRequest("", "/", nil) - r = r.WithContext(ctx) - - assert.Same(t, tracer.FieldLogger, GetLogger(r)) -} - -func TestGetLogger_ContextDoesNotContainValue(t *testing.T) { - r := httptest.NewRequest("", "/", nil) - l := GetLogger(r) - - assert.NotNil(t, l) - assert.Implements(t, (*logrus.FieldLogger)(nil), l) -} - diff --git a/tracing/opentracing.go b/tracing/opentracing.go deleted file mode 100644 index 2b78bbd..0000000 --- a/tracing/opentracing.go +++ /dev/null @@ -1,24 +0,0 @@ -package tracing - -import ( - "fmt" - - "github.com/opentracing/opentracing-go" - otlog "github.com/opentracing/opentracing-go/log" - "github.com/pkg/errors" -) - -type stackTracer interface { - StackTrace() errors.StackTrace -} - -func LogErrorToSpan(span opentracing.Span, err error) { - if err == nil || span == nil { - return - } - - span.LogFields(otlog.String("event", "error"), otlog.Error(err)) - if sterr, ok := err.(stackTracer); ok { - span.LogFields(otlog.String("stack", fmt.Sprintf("%+v", sterr.StackTrace()))) - } -} diff --git a/tracing/req_tracer.go b/tracing/req_tracer.go deleted file mode 100644 index 88fa692..0000000 --- a/tracing/req_tracer.go +++ /dev/null @@ -1,100 +0,0 @@ -package tracing - -import ( - "net/http" - "strconv" - "time" - - opentracing "github.com/opentracing/opentracing-go" - "github.com/sirupsen/logrus" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" -) - -type RequestTracer struct { - *trackingWriter - logrus.FieldLogger - - RequestID string - finalFields map[string]interface{} - - remoteAddr string - method string - originalURL string - referrer string - span opentracing.Span - start time.Time -} - -func NewTracer(w http.ResponseWriter, r *http.Request, log logrus.FieldLogger, service, resource string) (http.ResponseWriter, *http.Request, *RequestTracer) { - var reqID string - log, reqID = RequestLogger(r, log) - - r, span := WrapWithSpan(r, reqID, service, resource) - trackWriter := &trackingWriter{ - writer: w, - log: log, - } - - rt := &RequestTracer{ - originalURL: r.URL.String(), - method: r.Method, - referrer: r.Referer(), - remoteAddr: r.RemoteAddr, - - RequestID: reqID, - span: span, - trackingWriter: trackWriter, - FieldLogger: log, - finalFields: make(map[string]interface{}), - } - r = WrapWithTracer(r, rt) - - return rt, r, rt -} - -func (rt *RequestTracer) Start() { - rt.start = time.Now() - rt.WithFields(logrus.Fields{ - "method": rt.method, - "remote_addr": rt.remoteAddr, - "referer": rt.referrer, - "url": rt.originalURL, - }).Info("Starting Request") -} - -func (rt *RequestTracer) Finish() { - dur := time.Since(rt.start) - - fields := logrus.Fields{} - for k, v := range rt.finalFields { - fields[k] = v - } - - fields["status_code"] = rt.trackingWriter.status - fields["rsp_bytes"] = rt.trackingWriter.rspBytes - fields["url"] = rt.originalURL - fields["method"] = rt.method - fields["dur"] = dur.String() - fields["dur_ns"] = dur.Nanoseconds() - - // Setting the status as an int doesn't propogate for use in datadog dashboards, - // so we convert to a string. - rt.span.SetTag(ext.HTTPCode, strconv.Itoa(rt.trackingWriter.status)) - rt.span.Finish() - rt.WithFields(fields).Info("Completed Request") -} - -func (rt *RequestTracer) SetLogField(key string, value interface{}) logrus.FieldLogger { - rt.FieldLogger = rt.FieldLogger.WithField(key, value) - return rt.FieldLogger -} - -func (rt *RequestTracer) SetLogFields(fields logrus.Fields) logrus.FieldLogger { - rt.FieldLogger = rt.FieldLogger.WithFields(fields) - return rt.FieldLogger -} - -func (rt *RequestTracer) SetFinalField(key string, value interface{}) logrus.FieldLogger { - rt.finalFields[key] = value - return rt.FieldLogger.WithField(key, value) -} diff --git a/tracing/req_tracer_test.go b/tracing/req_tracer_test.go deleted file mode 100644 index b2ab8ff..0000000 --- a/tracing/req_tracer_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package tracing - -import ( - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/mocktracer" - logtest "github.com/sirupsen/logrus/hooks/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" -) - -func TestTracerLogging(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "http://whatever.com/something", nil) - - log, hook := logtest.NewNullLogger() - - _, r, rt := NewTracer(rec, req, log, t.Name(), "some_resource") - - rt.Start() - e := hook.LastEntry() - assert.Equal(t, 5, len(e.Data)) - assert.NotEmpty(t, e.Data["request_id"]) - assert.NotEmpty(t, e.Data["remote_addr"]) - assert.Empty(t, e.Data["referrer"]) - assert.NotEmpty(t, e.Data["method"]) - assert.Equal(t, "http://whatever.com/something", e.Data["url"]) - - _ = SetLogField(r, "first", "second") - SetFinalField(r, "final", "line").Info("should have the final here") - e = hook.LastEntry() - assert.Equal(t, 3, len(e.Data)) - assert.NotEmpty(t, e.Data["request_id"]) - assert.Equal(t, "line", e.Data["final"]) - assert.Equal(t, "second", e.Data["first"]) - - rt.Info("Shouldn't have the final line") - e = hook.LastEntry() - assert.Equal(t, 2, len(e.Data)) - assert.NotEmpty(t, e.Data["request_id"]) - assert.Equal(t, "second", e.Data["first"]) - - rt.WriteHeader(http.StatusOK) - rt.Write([]byte{0, 1, 2, 3}) - rt.Finish() - e = hook.LastEntry() - - assert.Equal(t, 9, len(e.Data)) - - // the automatic fields - assert.NotEmpty(t, e.Data["dur"]) - assert.NotEmpty(t, e.Data["dur_ns"]) - assert.NotEmpty(t, e.Data["request_id"]) - assert.Equal(t, 4, e.Data["rsp_bytes"]) - assert.Equal(t, 200, e.Data["status_code"]) - assert.Equal(t, "http://whatever.com/something", e.Data["url"]) - assert.Equal(t, "GET", e.Data["method"]) - - // the value that we added above - assert.Equal(t, "second", e.Data["first"]) - assert.Equal(t, "line", e.Data["final"]) -} - -func TestTracerSpans(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "http://whatever.com/something", nil) - - log, _ := logtest.NewNullLogger() - - mt := mocktracer.New() - opentracing.SetGlobalTracer(mt) - _, _, rt := NewTracer(rec, req, log, t.Name(), "some_resource") - rt.Start() - rt.WriteHeader(http.StatusOK) - rt.Finish() - - require.Len(t, mt.FinishedSpans(), 1) - span := mt.FinishedSpans()[0] - assert.Equal(t, t.Name(), span.Tag(ext.ServiceName)) - assert.Equal(t, "some_resource", span.Tag(ext.ResourceName)) - assert.Equal(t, http.MethodGet, span.Tag(ext.HTTPMethod)) - assert.Equal(t, strconv.Itoa(http.StatusOK), span.Tag(ext.HTTPCode)) - assert.Equal(t, rt.RequestID, span.Tag("http.request_id")) -} diff --git a/tracing/tracer.go b/tracing/tracer.go deleted file mode 100644 index 3ef9ceb..0000000 --- a/tracing/tracer.go +++ /dev/null @@ -1,48 +0,0 @@ -package tracing - -import ( - "net/http" - "strconv" - - opentracing "github.com/opentracing/opentracing-go" - ext "github.com/opentracing/opentracing-go/ext" - "github.com/sirupsen/logrus" - ddtrace_ext "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer" -) - -func TrackRequest(w http.ResponseWriter, r *http.Request, log logrus.FieldLogger, service, resource string, next http.Handler) { - w, r, rt := NewTracer(w, r, log, service, resource) - rt.Start() - next.ServeHTTP(w, r) - rt.Finish() -} - -func WrapWithSpan(r *http.Request, reqID, service, resource string) (*http.Request, opentracing.Span) { - span := opentracing.SpanFromContext(r.Context()) - if span != nil { - return r, span - } - - clientContext, _ := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header)) - span, ctx := opentracing.StartSpanFromContext(r.Context(), "http.handler", - ext.RPCServerOption(clientContext), - opentracer.ServiceName(service), - opentracer.ResourceName(resource), - opentracer.SpanType(ddtrace_ext.AppTypeWeb), - opentracing.Tag{Key: "http.content_length", Value: strconv.FormatInt(r.ContentLength, 10)}, - ) - - // datadog specific span.kind, normally "server" - ext.Component.Set(span, "net/http") - // "normal" is default request type until overridden - ext.HTTPMethod.Set(span, r.Method) - ext.HTTPUrl.Set(span, r.URL.String()) - scheme := "http" - if r.Header.Get("X-Forwarded-Proto") == "https" { - scheme = "https" - } - span.SetTag("http.base_url", scheme+"://"+r.Host) - span.SetTag("http.request_id", reqID) - return r.WithContext(ctx), span -} diff --git a/tracing/tracing_test.go b/tracing/tracing_test.go deleted file mode 100644 index 9c669e2..0000000 --- a/tracing/tracing_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package tracing - -import ( - "context" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestGetFromContext_ContextContainsValue(t *testing.T) { - tracer := new(RequestTracer) - ctx := context.WithValue(context.Background(), tracerKey, tracer) - - assert.Same(t, tracer, GetFromContext(ctx)) -} - -func TestGetFromContext_ContextDoesNotContainValue(t *testing.T) { - assert.Nil(t, GetFromContext(context.Background())) -} - -func TestGetTracer_ContextContainsValue(t *testing.T) { - tracer := new(RequestTracer) - ctx := context.WithValue(context.Background(), tracerKey, tracer) - - r := httptest.NewRequest("", "/", nil) - r = r.WithContext(ctx) - - assert.Same(t, tracer, GetTracer(r)) -} - -func TestGetTracer_ContextDoesNotContainValue(t *testing.T) { - r := httptest.NewRequest("", "/", nil) - - assert.Nil(t, GetTracer(r)) -} - - diff --git a/tracing/writer.go b/tracing/writer.go deleted file mode 100644 index 75ebe26..0000000 --- a/tracing/writer.go +++ /dev/null @@ -1,51 +0,0 @@ -package tracing - -import ( - "bufio" - "fmt" - "net" - "net/http" - - "github.com/sirupsen/logrus" -) - -var _ http.Hijacker = (*trackingWriter)(nil) - -type trackingWriter struct { - writer http.ResponseWriter - rspBytes int - written bool - status int - log logrus.FieldLogger -} - -func (w *trackingWriter) OriginalWriter() http.ResponseWriter { - return w.writer -} - -func (w *trackingWriter) Write(in []byte) (int, error) { - w.rspBytes += len(in) - return w.writer.Write(in) -} - -func (w *trackingWriter) WriteHeader(code int) { - if w.written { - w.log.Warnf("Attempted to write the header twice: %d first, %d second", w.status, code) - return - } - w.status = code - w.written = true - w.writer.WriteHeader(code) -} - -func (w *trackingWriter) Header() http.Header { - return w.writer.Header() -} - -func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hj, ok := w.writer.(http.Hijacker) - if !ok { - return nil, nil, fmt.Errorf("webserver doesn't support hijacking") - } - return hj.Hijack() -} diff --git a/util/atomic_bool.go b/util/atomic_bool.go deleted file mode 100644 index 7ee259c..0000000 --- a/util/atomic_bool.go +++ /dev/null @@ -1,41 +0,0 @@ -package util - -import "sync/atomic" - -const ( - falseValue int32 = 0 - trueValue int32 = 1 -) - -type AtomicBool interface { - Set(bool) bool - Get() bool -} - -type atomicBool struct { - value int32 -} - -func NewAtomicBool(val bool) AtomicBool { - a := &atomicBool{value: falseValue} - a.Set(val) - return a -} - -// Set will set the value to boolValue and will return the previous value -func (a *atomicBool) Set(boolValue bool) bool { - intValue := int32(falseValue) - if boolValue { - intValue = trueValue - } - return toTruthy(atomic.SwapInt32(&a.value, intValue)) -} - -// Get will return the current value -func (a *atomicBool) Get() bool { - return toTruthy(atomic.LoadInt32(&a.value)) -} - -func toTruthy(val int32) bool { - return val != falseValue -} diff --git a/util/atomic_bool_test.go b/util/atomic_bool_test.go deleted file mode 100644 index c9f322c..0000000 --- a/util/atomic_bool_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package util - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestAtomicBool(t *testing.T) { - ab := NewAtomicBool(false) - require.False(t, ab.Get()) - - ab = NewAtomicBool(true) - require.True(t, ab.Get()) - - require.True(t, ab.Set(false)) - require.False(t, ab.Get()) - - require.False(t, ab.Set(false)) - require.False(t, ab.Set(false)) - require.False(t, ab.Get()) - - require.False(t, ab.Set(true)) - require.True(t, ab.Get()) -} diff --git a/util/duration.go b/util/duration.go deleted file mode 100644 index 476705a..0000000 --- a/util/duration.go +++ /dev/null @@ -1,67 +0,0 @@ -package util - -import ( - "encoding/json" - "errors" - "time" -) - -// Duration is a serializable version version of a time.Duration -// it supports setting in yaml & json via: -// - string: 10s -// - float32/64, int/32/64: 10 (nanoseconds) -type Duration struct { - time.Duration -} - -func (d Duration) MarshalYAML() (interface{}, error) { - return d.String(), nil -} - -func (d Duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.String()) -} - -func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { - var v interface{} - if err := unmarshal(&v); err != nil { - return err - } - return d.setValue(v) -} - -func (d *Duration) UnmarshalJSON(b []byte) error { - var v interface{} - if err := json.Unmarshal(b, &v); err != nil { - return err - } - return d.setValue(v) -} - -func (d *Duration) UnmarshalText(text []byte) error { - return d.setValue(string(text)) -} - -func (d *Duration) setValue(v interface{}) error { - switch value := v.(type) { - case float64: - d.Duration = time.Duration(value) - case float32: - d.Duration = time.Duration(value) - case int: - d.Duration = time.Duration(value) - case int32: - d.Duration = time.Duration(value) - case int64: - d.Duration = time.Duration(value) - case string: - var err error - d.Duration, err = time.ParseDuration(value) - if err != nil { - return err - } - default: - return errors.New("invalid duration") - } - return nil -} diff --git a/util/duration_test.go b/util/duration_test.go deleted file mode 100644 index 233b081..0000000 --- a/util/duration_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package util - -import ( - "encoding/json" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -func TestDuration_Unmarshal(t *testing.T) { - testCases := []struct { - name string - url string - expected time.Duration - errCheck require.ErrorAssertionFunc - unmarshal func([]byte, interface{}) error - }{ - {"empty-json", `{"d": ""}`, 0, require.Error, json.Unmarshal}, - {"invalid-json", `{"d": "no duration here"}`, 0, require.Error, json.Unmarshal}, - {"valid-json-string", `{"d": "1s"}`, time.Second, require.NoError, json.Unmarshal}, - {"valid-json-int", `{"d": 1000000000}`, time.Second, require.NoError, json.Unmarshal}, - {"valid-json-float", `{"d": 1000000000.0}`, time.Second, require.NoError, json.Unmarshal}, - - {"empty-yaml", `u: ""}`, 0, require.Error, yaml.Unmarshal}, - {"invalid-yaml", `u: "no duration here"}`, 0, require.Error, yaml.Unmarshal}, - {"valid-yaml-string", `{"d": "1s"}`, time.Second, require.NoError, yaml.Unmarshal}, - {"valid-yaml-int", `{"d": 1000000000}`, time.Second, require.NoError, yaml.Unmarshal}, - {"valid-yaml-float", `{"d": 1000000000.0}`, time.Second, require.NoError, yaml.Unmarshal}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var s struct { - D Duration `json:"d"` - } - tc.errCheck(t, tc.unmarshal([]byte(tc.url), &s)) - assert.Equal(t, tc.expected, s.D.Duration) - }) - } -} - -func TestDuration_Marshal(t *testing.T) { - testCases := []struct { - name string - marshal func(interface{}) ([]byte, error) - expected string - }{ - {"json", json.Marshal, `{"d":"1s"}`}, - {"yaml", yaml.Marshal, "d: 1s\n"}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - u := struct { - D Duration `json:"d" yaml:"d"` - }{D: Duration{time.Second}} - - serialized, err := tc.marshal(u) - require.NoError(t, err) - assert.Equal(t, tc.expected, string(serialized)) - }) - } -} diff --git a/util/headers.go b/util/headers.go deleted file mode 100644 index 091f783..0000000 --- a/util/headers.go +++ /dev/null @@ -1,48 +0,0 @@ -package util - -import ( - "encoding/json" - "net/http" -) - -// Headers is a serializable version of http.Header it supports both yaml & json formats. -// Headers expects the yaml/json representation to be a map[string]string. -type Headers struct { - http.Header -} - -func (h Headers) MarshalYAML() (interface{}, error) { - return h.Header, nil -} - -func (h Headers) MarshalJSON() ([]byte, error) { - return json.Marshal(h.Header) -} - -func (h *Headers) UnmarshalYAML(unmarshal func(interface{}) error) error { - var headers map[string]string - if err := unmarshal(&headers); err != nil { - return err - } - - h.Header = http.Header{} - for k, v := range headers { - h.Add(k, v) - } - - return nil -} - -func (h *Headers) UnmarshalJSON(b []byte) error { - var headers map[string]string - if err := json.Unmarshal(b, &headers); err != nil { - return err - } - - h.Header = http.Header{} - for k, v := range headers { - h.Add(k, v) - } - - return nil -} diff --git a/util/headers_test.go b/util/headers_test.go deleted file mode 100644 index 556b45a..0000000 --- a/util/headers_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package util - -import ( - "encoding/json" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -func TestHeaders_Unmarshal(t *testing.T) { - testCases := []struct { - name string - url string - expected http.Header - errCheck require.ErrorAssertionFunc - unmarshal func([]byte, interface{}) error - }{ - {"invalid-json", `{"h": "<>:hey>"}`, nil, require.Error, json.Unmarshal}, - {"null-json", `{"h": null}`, http.Header{}, require.NoError, json.Unmarshal}, - {"empty-json", `{"h": {}}`, http.Header{}, require.NoError, json.Unmarshal}, - { - "valid-json", - `{"h": {"X-NF-TEST-A": "aAa", "x-nf-test-b": "bBb"}}`, - http.Header{"X-Nf-Test-A": {"aAa"}, "X-Nf-Test-B": {"bBb"}}, - require.NoError, - json.Unmarshal, - }, - { - "duplicate-json", - `{"h": {"X-NF-TEST-A": "aAa", "X-Nf-Test-A": "bBb", "x-nf-test-a": "cCc"}}`, - http.Header{"X-Nf-Test-A": {"aAa", "bBb", "cCc"}}, - require.NoError, - json.Unmarshal, - }, - - {"invalid-yaml", `h: ""`, nil, require.Error, yaml.Unmarshal}, - {"null-yaml", `h: null`, http.Header{}, require.NoError, yaml.Unmarshal}, - {"empty-yaml", `h: {}`, http.Header{}, require.NoError, yaml.Unmarshal}, - { - "valid-yaml", - `h: {"X-NF-TEST-A": "aAa", "x-nf-test-b": "bBb"}`, - http.Header{"X-Nf-Test-A": {"aAa"}, "X-Nf-Test-B": {"bBb"}}, - require.NoError, - yaml.Unmarshal, - }, - { - "duplicate-yaml", - `h: {"X-NF-TEST-A": "aAa", "X-Nf-Test-A": "bBb", "x-nf-test-a": "cCc"}`, - http.Header{"X-Nf-Test-A": {"aAa", "bBb", "cCc"}}, - require.NoError, - yaml.Unmarshal, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var s struct { - H Headers `json:"h"` - } - tc.errCheck(t, tc.unmarshal([]byte(tc.url), &s)) - assert.ObjectsAreEqualValues(tc.expected, s.H.Header) - }) - } -} - -func TestHeaders_Marshal(t *testing.T) { - testCases := []struct { - name string - marshal func(interface{}) ([]byte, error) - expected string - }{ - {"json", json.Marshal, `{"h":{"X-Nf-Test-A":["aAa","bBb"]}}`}, - {"yaml", yaml.Marshal, "h:\n X-Nf-Test-A:\n - aAa\n - bBb\n"}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - h := struct { - H Headers `json:"h" yaml:"h"` - }{H: Headers{http.Header{"X-Nf-Test-A": {"aAa", "bBb"}}}} - - serialized, err := tc.marshal(h) - require.NoError(t, err) - assert.Equal(t, tc.expected, string(serialized)) - }) - } -} diff --git a/util/scheduled_executor.go b/util/scheduled_executor.go deleted file mode 100644 index 6466a0a..0000000 --- a/util/scheduled_executor.go +++ /dev/null @@ -1,67 +0,0 @@ -package util - -import ( - "sync" - "time" -) - -type ScheduledExecutor interface { - Start() - Stop() -} - -type scheduledExecutor struct { - period time.Duration - cb func() - isRunning AtomicBool - ticker *time.Ticker - done chan bool - wg sync.WaitGroup -} - -func NewScheduledExecutor(period time.Duration, cb func()) ScheduledExecutor { - return &scheduledExecutor{ - period: period, - cb: cb, - isRunning: NewAtomicBool(false), - wg: sync.WaitGroup{}, - } -} - -func (s *scheduledExecutor) Start() { - if s.isRunning.Set(true) { - return - } - - s.ticker = time.NewTicker(s.period) - s.done = make(chan bool) - s.wg.Add(1) - - go s.poll() -} - -func (s *scheduledExecutor) Stop() { - if !s.isRunning.Set(false) { - return - } - - s.ticker.Stop() - s.done <- true - s.wg.Wait() - - s.ticker = nil - s.done = nil -} - -func (s *scheduledExecutor) poll() { - defer s.wg.Done() - - for { - select { - case <-s.done: - return - case <-s.ticker.C: - s.cb() - } - } -} diff --git a/util/url.go b/util/url.go deleted file mode 100644 index 831d435..0000000 --- a/util/url.go +++ /dev/null @@ -1,58 +0,0 @@ -package util - -import ( - "encoding/json" - "fmt" - "net/url" - - "github.com/pkg/errors" -) - -// URL is a serializable version version of a url.URL -// it supports serialization in yaml and json -type URL struct { - *url.URL -} - -func (u URL) MarshalYAML() (interface{}, error) { - return u.String(), nil -} - -func (u URL) MarshalJSON() ([]byte, error) { - return json.Marshal(u.String()) -} - -func (u *URL) UnmarshalYAML(unmarshal func(interface{}) error) error { - var v interface{} - if err := unmarshal(&v); err != nil { - return err - } - return u.setValue(v) -} - -func (u *URL) UnmarshalJSON(b []byte) error { - var v interface{} - if err := json.Unmarshal(b, &v); err != nil { - return err - } - return u.setValue(v) -} - -func (u *URL) setValue(v interface{}) error { - switch value := v.(type) { - case string: - if value == "" { - return errors.New("empty string provided as url") - } - - parsed, err := url.Parse(value) - if err != nil { - return errors.Wrap(err, "invalid url provided") - } - - u.URL = parsed - default: - return errors.New(fmt.Sprintf("invalid type provided as url: %T", v)) - } - return nil -} diff --git a/util/url_test.go b/util/url_test.go deleted file mode 100644 index 530fd91..0000000 --- a/util/url_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package util - -import ( - "encoding/json" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -func TestURL_Unmarshal(t *testing.T) { - testCases := []struct { - name string - url string - expected *url.URL - errCheck require.ErrorAssertionFunc - unmarshal func([]byte, interface{}) error - }{ - {"empty-json", `{"u": ""}`, nil, require.Error, json.Unmarshal}, - {"invalid-json", `{"u": "<>:hey>"}`, nil, require.Error, json.Unmarshal}, - {"valid-json", `{"u": "https://netlify.com"}`, &url.URL{Scheme: "https", Host: "netlify.com"}, require.NoError, json.Unmarshal}, - - {"empty-yaml", `u: ""}`, nil, require.Error, yaml.Unmarshal}, - {"invalid-yaml", `u: "<>:hey>"}`, nil, require.Error, yaml.Unmarshal}, - {"valid-yaml", `{u: https://netlify.com}`, &url.URL{Scheme: "https", Host: "netlify.com"}, require.NoError, yaml.Unmarshal}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var s struct { - U URL `json:"u"` - } - tc.errCheck(t, tc.unmarshal([]byte(tc.url), &s)) - assert.Equal(t, tc.expected, s.U.URL) - }) - } -} - -func TestURL_Marshal(t *testing.T) { - testCases := []struct { - name string - marshal func(interface{}) ([]byte, error) - expected string - }{ - {"json", json.Marshal, `{"u":"https://netlify.com"}`}, - {"yaml", yaml.Marshal, "u: https://netlify.com\n"}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - u := struct { - U URL `json:"u" yaml:"u"` - }{U: URL{&url.URL{Scheme: "https", Host: "netlify.com"}}} - - serialized, err := tc.marshal(u) - require.NoError(t, err) - assert.Equal(t, tc.expected, string(serialized)) - }) - } -}