Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions nethttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@ import (
"net/http"
"net/url"

opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)

var responseSizeKey = "http.response_size"

type mwOptions struct {
opNameFunc func(r *http.Request) string
spanFilter func(r *http.Request) bool
spanObserver func(span opentracing.Span, r *http.Request)
urlTagFunc func(u *url.URL) string
componentName string
opNameFunc func(r *http.Request) string
spanFilter func(r *http.Request) bool
startSpanOptions func(r *http.Request) []opentracing.StartSpanOption
spanObserver func(span opentracing.Span, r *http.Request)
urlTagFunc func(u *url.URL) string
componentName string
}

// MWOption controls the behavior of the Middleware.
Expand Down Expand Up @@ -49,6 +50,13 @@ func MWSpanFilter(f func(r *http.Request) bool) MWOption {
}
}

// MWStartSpanOptions returns a MWOption that creates options for starting a span.
func MWStartSpanOptions(f func(r *http.Request) []opentracing.StartSpanOption) MWOption {
return func(options *mwOptions) {
options.startSpanOptions = f
}
}

// MWSpanObserver returns a MWOption that observe the span
// for the server-side span.
func MWSpanObserver(f func(span opentracing.Span, r *http.Request)) MWOption {
Expand Down Expand Up @@ -105,8 +113,9 @@ func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOpti
opNameFunc: func(r *http.Request) string {
return "HTTP " + r.Method
},
spanFilter: func(r *http.Request) bool { return true },
spanObserver: func(span opentracing.Span, r *http.Request) {},
spanFilter: func(r *http.Request) bool { return true },
startSpanOptions: func(r *http.Request) []opentracing.StartSpanOption { return nil },
spanObserver: func(span opentracing.Span, r *http.Request) {},
urlTagFunc: func(u *url.URL) string {
return u.String()
},
Expand All @@ -126,7 +135,8 @@ func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOpti
return
}
ctx, _ := tr.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
sp := tr.StartSpan(opts.opNameFunc(r), ext.RPCServerOption(ctx))
startSpanOptions := collectStartSpanOptions(ctx, r, opts)
sp := tr.StartSpan(opts.opNameFunc(r), startSpanOptions...)
ext.HTTPMethod.Set(sp, r.Method)
ext.HTTPUrl.Set(sp, opts.urlTagFunc(r.URL))
ext.Component.Set(sp, componentName)
Expand Down Expand Up @@ -164,3 +174,13 @@ func MiddlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOpti
}
return http.HandlerFunc(fn)
}

func collectStartSpanOptions(ctx opentracing.SpanContext, r *http.Request, opts mwOptions) []opentracing.StartSpanOption {
mwStartSpanOptions := opts.startSpanOptions(r)

startSpanOptions := make([]opentracing.StartSpanOption, 0, len(mwStartSpanOptions)+1)
startSpanOptions = append(startSpanOptions, ext.RPCServerOption(ctx))
startSpanOptions = append(startSpanOptions, mwStartSpanOptions...)

return startSpanOptions
}
75 changes: 75 additions & 0 deletions nethttp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,81 @@ func TestSpanFilterOption(t *testing.T) {
}
}

func TestStartSpanOptionsOption(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/root", func(w http.ResponseWriter, r *http.Request) {})

const customTagForHTTPMethod = "custom_tag_for_http_method"

mwOptions := []MWOption{MWStartSpanOptions(func(r *http.Request) []opentracing.StartSpanOption {
return []opentracing.StartSpanOption{
opentracing.Tag{Key: customTagForHTTPMethod, Value: r.Method},
}
})}

tests := []struct { //nolint:govet
name string
httpMethod string
options []MWOption
expectMethodTag string
}{
{
name: "without options",
httpMethod: http.MethodGet,
options: nil,
expectMethodTag: "<nil>",
},
{
name: "with options and method = GET",
httpMethod: http.MethodGet,
options: mwOptions,
expectMethodTag: http.MethodGet,
},
{
name: "with options and method = PATCH",
httpMethod: http.MethodPatch,
options: mwOptions,
expectMethodTag: http.MethodPatch,
},
}

for _, tt := range tests {
testCase := tt
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
tr := &mocktracer.MockTracer{}
mw := Middleware(tr, mux, testCase.options...)
srv := httptest.NewServer(mw)
defer srv.Close()

req, err := http.NewRequestWithContext(context.Background(), testCase.httpMethod, srv.URL, nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("server returned error: %v", err)
}
defer resp.Body.Close()

spans := tr.FinishedSpans()
if got, want := len(spans), 1; got != want {
t.Fatalf("got %d spans, expected %d", got, want)
}

tag, ok := spans[0].Tags()[customTagForHTTPMethod].(string)
if !ok {
tag = "<nil>"
}
if got, want := tag, testCase.expectMethodTag; got != want {
t.Fatalf("got %s tag name, expected %s", got, want)
}
})
}
}

func TestURLTagOption(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
Expand Down