diff --git a/nethttp/server.go b/nethttp/server.go index 322e88f..9eb8fb4 100644 --- a/nethttp/server.go +++ b/nethttp/server.go @@ -7,18 +7,19 @@ import ( "net/http" "net/url" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" ) var responseSizeKey = "http.response_size" type mwOptions struct { - opNameFunc func(r *http.Request) string - spanFilter func(r *http.Request) bool - spanObserver func(span opentracing.Span, r *http.Request) - urlTagFunc func(u *url.URL) string - componentName string + opNameFunc func(r *http.Request) string + spanFilter func(r *http.Request) bool + startSpanOptions func(r *http.Request) []opentracing.StartSpanOption + spanObserver func(span opentracing.Span, r *http.Request) + urlTagFunc func(u *url.URL) string + componentName string } // MWOption controls the behavior of the Middleware. @@ -49,6 +50,14 @@ func MWSpanFilter(f func(r *http.Request) bool) MWOption { } } +// MWStartSpanOptions returns a MWOption that creates options for starting a span. +// Middleware first applies default StartSpanOptions, followed by the ones supplied by the user. +func MWStartSpanOptions(f func(r *http.Request) []opentracing.StartSpanOption) MWOption { + return func(options *mwOptions) { + options.startSpanOptions = f + } +} + // MWSpanObserver returns a MWOption that observe the span // for the server-side span. func MWSpanObserver(f func(span opentracing.Span, r *http.Request)) MWOption { @@ -105,8 +114,9 @@ func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOpti opNameFunc: func(r *http.Request) string { return "HTTP " + r.Method }, - spanFilter: func(r *http.Request) bool { return true }, - spanObserver: func(span opentracing.Span, r *http.Request) {}, + spanFilter: func(r *http.Request) bool { return true }, + startSpanOptions: func(r *http.Request) []opentracing.StartSpanOption { return nil }, + spanObserver: func(span opentracing.Span, r *http.Request) {}, urlTagFunc: func(u *url.URL) string { return u.String() }, @@ -126,7 +136,8 @@ func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOpti return } ctx, _ := tr.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header)) - sp := tr.StartSpan(opts.opNameFunc(r), ext.RPCServerOption(ctx)) + startSpanOptions := collectStartSpanOptions(ctx, r, opts) + sp := tr.StartSpan(opts.opNameFunc(r), startSpanOptions...) ext.HTTPMethod.Set(sp, r.Method) ext.HTTPUrl.Set(sp, opts.urlTagFunc(r.URL)) ext.Component.Set(sp, componentName) @@ -164,3 +175,13 @@ func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOpti } return http.HandlerFunc(fn) } + +func collectStartSpanOptions(ctx opentracing.SpanContext, r *http.Request, opts mwOptions) []opentracing.StartSpanOption { + mwStartSpanOptions := opts.startSpanOptions(r) + + startSpanOptions := make([]opentracing.StartSpanOption, 0, len(mwStartSpanOptions)+1) + startSpanOptions = append(startSpanOptions, ext.RPCServerOption(ctx)) + startSpanOptions = append(startSpanOptions, mwStartSpanOptions...) + + return startSpanOptions +} diff --git a/nethttp/server_test.go b/nethttp/server_test.go index d72235b..5f1118b 100644 --- a/nethttp/server_test.go +++ b/nethttp/server_test.go @@ -182,6 +182,81 @@ func TestSpanFilterOption(t *testing.T) { } } +func TestStartSpanOptionsOption(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {}) + + const customTagForHTTPMethod = "custom_tag_for_http_method" + + mwOptions := []MWOption{MWStartSpanOptions(func(r *http.Request) []opentracing.StartSpanOption { + return []opentracing.StartSpanOption{ + opentracing.Tag{Key: customTagForHTTPMethod, Value: r.Method}, + } + })} + + tests := []struct { //nolint:govet + name string + httpMethod string + options []MWOption + expectMethodTag string + }{ + { + name: "without options", + httpMethod: http.MethodGet, + options: nil, + expectMethodTag: "", + }, + { + name: "with options and method = GET", + httpMethod: http.MethodGet, + options: mwOptions, + expectMethodTag: http.MethodGet, + }, + { + name: "with options and method = PATCH", + httpMethod: http.MethodPatch, + options: mwOptions, + expectMethodTag: http.MethodPatch, + }, + } + + for _, tt := range tests { + testCase := tt + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + tr := &mocktracer.MockTracer{} + mw := Middleware(tr, mux, testCase.options...) + srv := httptest.NewServer(mw) + defer srv.Close() + + req, err := http.NewRequestWithContext(context.Background(), testCase.httpMethod, srv.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("server returned error: %v", err) + } + defer resp.Body.Close() + + spans := tr.FinishedSpans() + if got, want := len(spans), 1; got != want { + t.Fatalf("got %d spans, expected %d", got, want) + } + + tag, ok := spans[0].Tags()[customTagForHTTPMethod].(string) + if !ok { + tag = "" + } + if got, want := tag, testCase.expectMethodTag; got != want { + t.Fatalf("got %s tag name, expected %s", got, want) + } + }) + } +} + func TestURLTagOption(t *testing.T) { t.Parallel() mux := http.NewServeMux()