diff --git a/filter/BUILD.bazel b/filter/BUILD.bazel new file mode 100644 index 000000000..ddb400246 --- /dev/null +++ b/filter/BUILD.bazel @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = [ + "filter.go", + "http.go", + ], + importpath = "github.com/reddit/baseplate.go/filter", + visibility = ["//visibility:public"], +) + +go_test( + name = "go_default_test", + srcs = ["http_test.go"], + embed = [":go_default_library"], +) diff --git a/filter/filter.go b/filter/filter.go new file mode 100644 index 000000000..2414ad51e --- /dev/null +++ b/filter/filter.go @@ -0,0 +1,30 @@ +package filter + +import "context" + +// Filter is a generic middleware type +type Filter interface { + Do(ctx context.Context, request interface{}, service Service) (response interface{}, err error) +} + +// Service is a generic client/server type +type Service interface { + Do(ctx context.Context, request interface{}) (response interface{}, err error) +} + +// ServiceWithFilters applies the filters to a service in a standard way. +func ServiceWithFilters(service Service, filters ...Filter) Service { + for i := len(filters) - 1; i >= 0; i-- { + service = &filteredService{filter: filters[i], service: service} + } + return service +} + +type filteredService struct { + filter Filter + service Service +} + +func (fs *filteredService) Do(ctx context.Context, request interface{}) (response interface{}, err error) { + return fs.filter.Do(ctx, request, fs.service) +} diff --git a/filter/http.go b/filter/http.go new file mode 100644 index 000000000..de889f6c6 --- /dev/null +++ b/filter/http.go @@ -0,0 +1,90 @@ +package filter + +import ( + "context" + "errors" + "io" + "net/http" + "net/url" + "strings" +) + +//HTTPClientWithFilters applies filter middleware to an http client. +func HTTPClientWithFilters(client *http.Client, filters ...Filter) *Client { + svc := HTTPClientAsService(client) + withFilters := ServiceWithFilters(svc, filters...) + return httpClientAdapter(withFilters) +} + +func httpClientAdapter(service Service) *Client { + return &Client{inner: service} +} + +// HTTPClientAsService represents an http.Client as a Service +func HTTPClientAsService(client *http.Client) Service { + return &httpClientService{client} +} + +type httpClientService struct { + client *http.Client +} + +// Client is duck-typed like http.Client, but internally implemented by a Service. +type Client struct { + inner Service +} + +func (svc *httpClientService) Do(ctx context.Context, request interface{}) (response interface{}, err error) { + httpRequest, ok := request.(*http.Request) + if !ok { + return nil, errors.New("not an http request") + } + // We copy to apply the appropriate context + if httpRequest.Context() != ctx { + httpRequest = httpRequest.Clone(ctx) + } + return svc.client.Do(httpRequest) +} + +// Do is a copy of http.Do +func (c *Client) Do(req *http.Request) (resp *http.Response, err error) { + r, err := c.inner.Do(req.Context(), req) + resp, ok := r.(*http.Response) + if !ok && err == nil { + return nil, errors.New("not an http response") + } + return +} + +// Get is a copy of http.Get +func (c *Client) Get(url string) (resp *http.Response, err error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// Head is a copy of http.Head +func (c *Client) Head(url string) (resp *http.Response, err error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// Post is a copy of http.Post +func (c *Client) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + return c.Do(req) +} + +// PostForm is a copy of http.PostForm +func (c *Client) PostForm(url string, data url.Values) (resp *http.Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} diff --git a/filter/http_test.go b/filter/http_test.go new file mode 100644 index 000000000..9514b9102 --- /dev/null +++ b/filter/http_test.go @@ -0,0 +1,52 @@ +package filter + +import ( + "context" + "errors" + "net/http" + "testing" + "time" +) + +type helloFilter struct { + msg string +} + +func (f *helloFilter) Do(ctx context.Context, request interface{}, service Service) (rsp interface{}, err error) { + rsp, err = service.Do(ctx, request) + httpRsp, ok := rsp.(*http.Response) + if ok { + if err == nil { + httpRsp.Header.Add("hello", f.msg) + } + } else { + err = errors.New("not an http response") + } + return +} + +type slowFilter struct { + duration time.Duration +} + +func (f *slowFilter) Do(ctx context.Context, request interface{}, service Service) (rsp interface{}, err error) { + time.Sleep(f.duration) + return service.Do(ctx, request) +} + +func TestHttpClientWithSpecificFilter(t *testing.T) { + client := HTTPClientWithFilters(&http.Client{}, &helloFilter{msg: "world"}) + rsp, _ := client.Get("https://google.com/") + if rsp.Header.Get("hello") != "world" { + t.Errorf("didn't set response header") + } +} +func TestHttpClientWithGenericFilter(t *testing.T) { + sleepFor := 1 * time.Second + client := HTTPClientWithFilters(&http.Client{}, &slowFilter{duration: sleepFor}) + start := time.Now() + _, _ = client.Get("https://google.com/") + if time.Now().Sub(start) < sleepFor { + t.Error("Didn't sleep long enough") + } +}