diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index c31625b2a..2702ef2ad 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -10,6 +10,10 @@ import ( // It represents a single RPC method. type Endpoint func(ctx context.Context, request interface{}) (response interface{}, err error) +// Nop is an endpoint that does nothing and returns a nil error. +// Useful for tests. +func Nop(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + // Middleware is a chainable behavior modifier for endpoints. type Middleware func(Endpoint) Endpoint diff --git a/examples/README.md b/examples/README.md index c005e3ed6..b43b69670 100644 --- a/examples/README.md +++ b/examples/README.md @@ -75,7 +75,7 @@ type uppercaseRequest struct { type uppercaseResponse struct { V string `json:"v"` - Err string `json:"err,omitempty"` // errors don't define JSON marshaling + Err string `json:"err,omitempty"` // errors don't JSON-marshal, so we use a string } type countRequest struct { @@ -98,6 +98,7 @@ type Endpoint func(ctx context.Context, request interface{}) (response interface An endpoint represents a single RPC. That is, a single method in our service interface. We'll write simple adapters to convert each of our service's methods into an endpoint. +Each adapter takes a StringService, and returns an endpoint that corresponds to one of the methods. ```go import ( @@ -281,9 +282,9 @@ Since our StringService is defined as an interface, we just need to make a new t which wraps an existing StringService, and performs the extra logging duties. ```go -type loggingMiddleware struct{ +type loggingMiddleware struct { logger log.Logger - StringService + next StringService } func (mw loggingMiddleware) Uppercase(s string) (output string, err error) { @@ -297,7 +298,7 @@ func (mw loggingMiddleware) Uppercase(s string) (output string, err error) { ) }(time.Now()) - output, err = mw.StringService.Uppercase(s) + output, err = mw.next.Uppercase(s) return } @@ -311,7 +312,7 @@ func (mw loggingMiddleware) Count(s string) (n int) { ) }(time.Now()) - n = mw.StringService.Count(s) + n = mw.next.Count(s) return } ``` @@ -329,9 +330,12 @@ import ( func main() { logger := log.NewLogfmtLogger(os.Stderr) - svc := stringService{} + var svc StringService + svc = stringsvc{} svc = loggingMiddleware{logger, svc} + // ... + uppercaseHandler := httptransport.NewServer( // ... makeUppercaseEndpoint(svc), @@ -364,7 +368,7 @@ type instrumentingMiddleware struct { requestCount metrics.Counter requestLatency metrics.TimeHistogram countResult metrics.Histogram - StringService + next StringService } func (mw instrumentingMiddleware) Uppercase(s string) (output string, err error) { @@ -375,7 +379,7 @@ func (mw instrumentingMiddleware) Uppercase(s string) (output string, err error) mw.requestLatency.With(methodField).With(errorField).Observe(time.Since(begin)) }(time.Now()) - output, err = mw.StringService.Uppercase(s) + output, err = mw.next.Uppercase(s) return } @@ -388,7 +392,7 @@ func (mw instrumentingMiddleware) Count(s string) (n int) { mw.countResult.Observe(int64(n)) }(time.Now()) - n = mw.StringService.Count(s) + n = mw.next.Count(s) return } ``` @@ -416,21 +420,12 @@ func main() { // ... }, []string{})) - svc := stringService{} + var svc StringService + svc = stringService{} svc = loggingMiddleware{logger, svc} svc = instrumentingMiddleware{requestCount, requestLatency, countResult, svc} - uppercaseHandler := httptransport.NewServer( - // ... - makeUppercaseEndpoint(svc), - // ... - ) - - countHandler := httptransport.NewServer( - // ... - makeCountEndpoint(svc), - // ... - ) + // ... http.Handle("/metrics", stdprometheus.Handler()) } @@ -467,17 +462,19 @@ Often, you need to call other services. **This is where Go kit shines**. We provide transport middlewares to solve many of the problems that come up. -Let's implement the proxying middleware as a ServiceMiddleware. -We'll only proxy one method, Uppercase. +Let's say that we want to have our string service call out to a _different_ string service + to satisfy the Uppercase method. +In effect, proxying the request to another service. +Let's implement the proxying middleware as a ServiceMiddleware, same as a logging or instrumenting middleware. ```go // proxymw implements StringService, forwarding Uppercase requests to the // provided endpoint, and serving all other (i.e. Count) requests via the -// embedded StringService. +// next StringService. type proxymw struct { - context.Context - StringService // Serve most requests via this embedded service... - UppercaseEndpoint endpoint.Endpoint // ...except Uppercase, which gets served by this endpoint + ctx context.Context + next StringService // Serve most requests via this service... + uppercase endpoint.Endpoint // ...except Uppercase, which gets served by this endpoint } ``` @@ -489,7 +486,7 @@ And to invoke the client endpoint, we just do some simple conversions. ```go func (mw proxymw) Uppercase(s string) (string, error) { - response, err := mw.UppercaseEndpoint(mw.Context, uppercaseRequest{S: s}) + response, err := mw.uppercase(mw.Context, uppercaseRequest{S: s}) if err != nil { return "", err } @@ -533,15 +530,15 @@ We want to discover them through some service discovery mechanism, and spread ou And if any of those instances start to behave badly, we want to deal with that, without affecting our own service's reliability. Go kit offers adapters to different service discovery systems, to get up-to-date sets of instances, exposed as individual endpoints. -Those adapters are called publishers. +Those adapters are called subscribers. ```go -type Publisher interface { +type Subscriber interface { Endpoints() ([]endpoint.Endpoint, error) } ``` -Internally, publishers use a provided factory function to convert each discovered host:port string to a usable endpoint. +Internally, subscribers use a provided factory function to convert each discovered instance string (typically host:port) to a usable endpoint. ```go type Factory func(instance string) (endpoint.Endpoint, error) @@ -551,23 +548,19 @@ So far, our factory function, makeUppercaseEndpoint, just calls the URL directly But it's important to put some safety middleware, like circuit breakers and rate limiters, into your factory, too. ```go -func factory(ctx context.Context, maxQPS int) loadbalancer.Factory { - return func(instance string) (endpoint.Endpoint, error) { - var e endpoint.Endpoint - e = makeUppercaseProxy(ctx, instance) - e = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{}))(e) - e = kitratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(float64(maxQPS), int64(maxQPS)))(e) - return e, nil - } +var e endpoint.Endpoint +e = makeUppercaseProxy(ctx, instance) +e = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{}))(e) +e = kitratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(float64(maxQPS), int64(maxQPS)))(e) } ``` Now that we've got a set of endpoints, we need to choose one. -Load balancers wrap publishers, and select one endpoint from many. +Load balancers wrap subscribers, and select one endpoint from many. Go kit provides a couple of basic load balancers, and it's easy to write your own if you want more advanced heuristics. ```go -type LoadBalancer interface { +type Balancer interface { Endpoint() (endpoint.Endpoint, error) } ``` @@ -578,24 +571,52 @@ A retry strategy wraps a load balancer, and returns a usable endpoint. The retry strategy will retry failed requests until either the max attempts or timeout has been reached. ```go -func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint +func Retry(max int, timeout time.Duration, lb Balancer) endpoint.Endpoint ``` Let's wire up our final proxying middleware. For simplicity, we'll assume the user will specify multiple comma-separate instance endpoints with a flag. ```go -func proxyingMiddleware(proxyList string, ctx context.Context, logger log.Logger) ServiceMiddleware { +func proxyingMiddleware(instances string, ctx context.Context, logger log.Logger) ServiceMiddleware { + // If instances is empty, don't proxy. + if instances == "" { + logger.Log("proxy_to", "none") + return func(next StringService) StringService { return next } + } + + // Set some parameters for our client. + var ( + qps = 100 // beyond which we will return an error + maxAttempts = 3 // per request, before giving up + maxTime = 250 * time.Millisecond // wallclock time, before giving up + ) + + // Otherwise, construct an endpoint for each instance in the list, and add + // it to a fixed set of endpoints. In a real service, rather than doing this + // by hand, you'd probably use package sd's support for your service + // discovery system. + var ( + instanceList = split(instances) + subscriber sd.FixedSubscriber + ) + logger.Log("proxy_to", fmt.Sprint(instanceList)) + for _, instance := range instanceList { + var e endpoint.Endpoint + e = makeUppercaseProxy(ctx, instance) + e = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{}))(e) + e = kitratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(float64(qps), int64(qps)))(e) + subscriber = append(subscriber, e) + } + + // Now, build a single, retrying, load-balancing endpoint out of all of + // those individual endpoints. + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(maxAttempts, maxTime, balancer) + + // And finally, return the ServiceMiddleware, implemented by proxymw. return func(next StringService) StringService { - var ( - qps = 100 // max to each instance - publisher = static.NewPublisher(split(proxyList), factory(ctx, qps), logger) - lb = loadbalancer.NewRoundRobin(publisher) - maxAttempts = 3 - maxTime = 100 * time.Millisecond - endpoint = loadbalancer.Retry(maxAttempts, maxTime, lb) - ) - return proxymw{ctx, endpoint, next} + return proxymw{ctx, next, retry} } } ``` @@ -667,13 +688,15 @@ See [package tracing](https://github.com/go-kit/kit/blob/master/tracing) for mor It's possible to use Go kit to create a client package to your service, to make consuming your service easier from other Go programs. Effectively, your client package will provide an implementation of your service interface, which invokes a remote service instance using a specific transport. -See [package addsvc/client](https://github.com/go-kit/kit/tree/master/examples/addsvc/client) for an example. +See [package addsvc/client](https://github.com/go-kit/kit/tree/master/examples/addsvc/client) + or [package profilesvc/client](https://github.com/go-kit/kit/tree/master/examples/profilesvc/client) + for examples. ## Other examples ### addsvc -[addsvc](https://github.com/go-kit/kit/blob/master/examples/addsvc) was the original example application. +[addsvc](https://github.com/go-kit/kit/blob/master/examples/addsvc) is the original example service. It exposes a set of operations over **all supported transports**. It's fully logged, instrumented, and uses Zipkin request tracing. It also demonstrates how to create and use client packages. diff --git a/examples/addsvc/client/client.go b/examples/addsvc/client/client.go deleted file mode 100644 index 7eb2968a8..000000000 --- a/examples/addsvc/client/client.go +++ /dev/null @@ -1,71 +0,0 @@ -package main - -import ( - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/server" - "github.com/go-kit/kit/log" -) - -// NewClient returns an AddService that's backed by the provided Endpoints -func newClient(ctx context.Context, sumEndpoint endpoint.Endpoint, concatEndpoint endpoint.Endpoint, logger log.Logger) server.AddService { - return client{ - Context: ctx, - Logger: logger, - sum: sumEndpoint, - concat: concatEndpoint, - } -} - -type client struct { - context.Context - log.Logger - sum endpoint.Endpoint - concat endpoint.Endpoint -} - -// TODO(pb): If your service interface methods don't return an error, we have -// no way to signal problems with a service client. If they don't take a -// context, we have to provide a global context for any transport that -// requires one, effectively making your service a black box to any context- -// specific information. So, we should make some recommendations: -// -// - To get started, a simple service interface is probably fine. -// -// - To properly deal with transport errors, every method on your service -// should return an error. This is probably important. -// -// - To properly deal with context information, every method on your service -// can take a context as its first argument. This may or may not be -// important. - -func (c client) Sum(a, b int) int { - request := server.SumRequest{ - A: a, - B: b, - } - reply, err := c.sum(c.Context, request) - if err != nil { - c.Logger.Log("err", err) // Without an error return parameter, we can't do anything else... - return 0 - } - - r := reply.(server.SumResponse) - return r.V -} - -func (c client) Concat(a, b string) string { - request := server.ConcatRequest{ - A: a, - B: b, - } - reply, err := c.concat(c.Context, request) - if err != nil { - c.Logger.Log("err", err) // Without an error return parameter, we can't do anything else... - return "" - } - - r := reply.(server.ConcatResponse) - return r.V -} diff --git a/examples/addsvc/client/grpc/client.go b/examples/addsvc/client/grpc/client.go new file mode 100644 index 000000000..e72bbe4de --- /dev/null +++ b/examples/addsvc/client/grpc/client.go @@ -0,0 +1,75 @@ +// Package grpc provides a gRPC client for the add service. +package grpc + +import ( + "time" + + jujuratelimit "github.com/juju/ratelimit" + stdopentracing "github.com/opentracing/opentracing-go" + "github.com/sony/gobreaker" + "google.golang.org/grpc" + + "github.com/go-kit/kit/circuitbreaker" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/examples/addsvc" + "github.com/go-kit/kit/examples/addsvc/pb" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/ratelimit" + "github.com/go-kit/kit/tracing/opentracing" + grpctransport "github.com/go-kit/kit/transport/grpc" +) + +// New returns an AddService backed by a gRPC client connection. It is the +// responsibility of the caller to dial, and later close, the connection. +func New(conn *grpc.ClientConn, tracer stdopentracing.Tracer, logger log.Logger) addsvc.Service { + // We construct a single ratelimiter middleware, to limit the total outgoing + // QPS from this client to all methods on the remote instance. We also + // construct per-endpoint circuitbreaker middlewares to demonstrate how + // that's done, although they could easily be combined into a single breaker + // for the entire remote instance, too. + + limiter := ratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(100, 100)) + + var sumEndpoint endpoint.Endpoint + { + sumEndpoint = grpctransport.NewClient( + conn, + "Add", + "Sum", + addsvc.EncodeGRPCSumRequest, + addsvc.DecodeGRPCSumResponse, + pb.SumReply{}, + grpctransport.SetClientBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)), + ).Endpoint() + sumEndpoint = opentracing.TraceClient(tracer, "Sum")(sumEndpoint) + sumEndpoint = limiter(sumEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Sum", + Timeout: 30 * time.Second, + }))(sumEndpoint) + } + + var concatEndpoint endpoint.Endpoint + { + concatEndpoint = grpctransport.NewClient( + conn, + "Add", + "Concat", + addsvc.EncodeGRPCConcatRequest, + addsvc.DecodeGRPCConcatResponse, + pb.ConcatReply{}, + grpctransport.SetClientBefore(opentracing.FromGRPCRequest(tracer, "Concat", logger)), + ).Endpoint() + concatEndpoint = opentracing.TraceClient(tracer, "Concat")(concatEndpoint) + concatEndpoint = limiter(concatEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Concat", + Timeout: 30 * time.Second, + }))(sumEndpoint) + } + + return addsvc.Endpoints{ + SumEndpoint: sumEndpoint, + ConcatEndpoint: concatEndpoint, + } +} diff --git a/examples/addsvc/client/grpc/encode_decode.go b/examples/addsvc/client/grpc/encode_decode.go deleted file mode 100644 index 23e081d06..000000000 --- a/examples/addsvc/client/grpc/encode_decode.go +++ /dev/null @@ -1,38 +0,0 @@ -package grpc - -import ( - "golang.org/x/net/context" - - "github.com/go-kit/kit/examples/addsvc/pb" - "github.com/go-kit/kit/examples/addsvc/server" -) - -func encodeSumRequest(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(server.SumRequest) - return &pb.SumRequest{ - A: int64(req.A), - B: int64(req.B), - }, nil -} - -func encodeConcatRequest(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(server.ConcatRequest) - return &pb.ConcatRequest{ - A: req.A, - B: req.B, - }, nil -} - -func decodeSumResponse(ctx context.Context, response interface{}) (interface{}, error) { - resp := response.(*pb.SumReply) - return server.SumResponse{ - V: int(resp.V), - }, nil -} - -func decodeConcatResponse(ctx context.Context, response interface{}) (interface{}, error) { - resp := response.(*pb.ConcatReply) - return server.ConcatResponse{ - V: resp.V, - }, nil -} diff --git a/examples/addsvc/client/grpc/factory.go b/examples/addsvc/client/grpc/factory.go deleted file mode 100644 index 0e5993268..000000000 --- a/examples/addsvc/client/grpc/factory.go +++ /dev/null @@ -1,51 +0,0 @@ -package grpc - -import ( - "io" - - kitot "github.com/go-kit/kit/tracing/opentracing" - "github.com/opentracing/opentracing-go" - "google.golang.org/grpc" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/pb" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" - grpctransport "github.com/go-kit/kit/transport/grpc" -) - -// MakeSumEndpointFactory returns a loadbalancer.Factory that transforms GRPC -// host:port strings into Endpoints that call the Sum method on a GRPC server -// at that address. -func MakeSumEndpointFactory(tracer opentracing.Tracer, tracingLogger log.Logger) loadbalancer.Factory { - return func(instance string) (endpoint.Endpoint, io.Closer, error) { - cc, err := grpc.Dial(instance, grpc.WithInsecure()) - return grpctransport.NewClient( - cc, - "Add", - "Sum", - encodeSumRequest, - decodeSumResponse, - pb.SumReply{}, - grpctransport.SetClientBefore(kitot.ToGRPCRequest(tracer, tracingLogger)), - ).Endpoint(), cc, err - } -} - -// MakeConcatEndpointFactory returns a loadbalancer.Factory that transforms -// GRPC host:port strings into Endpoints that call the Concat method on a GRPC -// server at that address. -func MakeConcatEndpointFactory(tracer opentracing.Tracer, tracingLogger log.Logger) loadbalancer.Factory { - return func(instance string) (endpoint.Endpoint, io.Closer, error) { - cc, err := grpc.Dial(instance, grpc.WithInsecure()) - return grpctransport.NewClient( - cc, - "Add", - "Concat", - encodeConcatRequest, - decodeConcatResponse, - pb.ConcatReply{}, - grpctransport.SetClientBefore(kitot.ToGRPCRequest(tracer, tracingLogger)), - ).Endpoint(), cc, err - } -} diff --git a/examples/addsvc/client/http/client.go b/examples/addsvc/client/http/client.go new file mode 100644 index 000000000..c597c98b9 --- /dev/null +++ b/examples/addsvc/client/http/client.go @@ -0,0 +1,86 @@ +// Package http provides an HTTP client for the add service. +package http + +import ( + "net/url" + "strings" + "time" + + jujuratelimit "github.com/juju/ratelimit" + stdopentracing "github.com/opentracing/opentracing-go" + "github.com/sony/gobreaker" + + "github.com/go-kit/kit/circuitbreaker" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/examples/addsvc" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/ratelimit" + "github.com/go-kit/kit/tracing/opentracing" + httptransport "github.com/go-kit/kit/transport/http" +) + +// New returns an AddService backed by an HTTP server living at the remote +// instance. We expect instance to come from a service discovery system, so +// likely of the form "host:port". +func New(instance string, tracer stdopentracing.Tracer, logger log.Logger) (addsvc.Service, error) { + if !strings.HasPrefix(instance, "http") { + instance = "http://" + instance + } + u, err := url.Parse(instance) + if err != nil { + return nil, err + } + + // We construct a single ratelimiter middleware, to limit the total outgoing + // QPS from this client to all methods on the remote instance. We also + // construct per-endpoint circuitbreaker middlewares to demonstrate how + // that's done, although they could easily be combined into a single breaker + // for the entire remote instance, too. + + limiter := ratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(100, 100)) + + var sumEndpoint endpoint.Endpoint + { + sumEndpoint = httptransport.NewClient( + "POST", + copyURL(u, "/sum"), + addsvc.EncodeHTTPGenericRequest, + addsvc.DecodeHTTPSumResponse, + httptransport.SetClientBefore(opentracing.FromHTTPRequest(tracer, "Sum", logger)), + ).Endpoint() + sumEndpoint = opentracing.TraceClient(tracer, "Sum")(sumEndpoint) + sumEndpoint = limiter(sumEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Sum", + Timeout: 30 * time.Second, + }))(sumEndpoint) + } + + var concatEndpoint endpoint.Endpoint + { + concatEndpoint = httptransport.NewClient( + "POST", + copyURL(u, "/concat"), + addsvc.EncodeHTTPGenericRequest, + addsvc.DecodeHTTPConcatResponse, + httptransport.SetClientBefore(opentracing.FromHTTPRequest(tracer, "Concat", logger)), + ).Endpoint() + concatEndpoint = opentracing.TraceClient(tracer, "Concat")(concatEndpoint) + concatEndpoint = limiter(concatEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Concat", + Timeout: 30 * time.Second, + }))(sumEndpoint) + } + + return addsvc.Endpoints{ + SumEndpoint: sumEndpoint, + ConcatEndpoint: concatEndpoint, + }, nil +} + +func copyURL(base *url.URL, path string) *url.URL { + next := *base + next.Path = path + return &next +} diff --git a/examples/addsvc/client/httpjson/factory.go b/examples/addsvc/client/httpjson/factory.go deleted file mode 100644 index f36ce4e43..000000000 --- a/examples/addsvc/client/httpjson/factory.go +++ /dev/null @@ -1,65 +0,0 @@ -package httpjson - -import ( - "io" - "net/url" - - "github.com/opentracing/opentracing-go" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/server" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" - kitot "github.com/go-kit/kit/tracing/opentracing" - httptransport "github.com/go-kit/kit/transport/http" -) - -// MakeSumEndpointFactory generates a Factory that transforms an http url into -// an Endpoint. -// -// The path of the url is reset to /sum. -func MakeSumEndpointFactory(tracer opentracing.Tracer, tracingLogger log.Logger) loadbalancer.Factory { - return func(instance string) (endpoint.Endpoint, io.Closer, error) { - sumURL, err := url.Parse(instance) - if err != nil { - return nil, nil, err - } - sumURL.Path = "/sum" - - client := httptransport.NewClient( - "GET", - sumURL, - server.EncodeSumRequest, - server.DecodeSumResponse, - httptransport.SetClient(nil), - httptransport.SetClientBefore(kitot.ToHTTPRequest(tracer, tracingLogger)), - ) - - return client.Endpoint(), nil, nil - } -} - -// MakeConcatEndpointFactory generates a Factory that transforms an http url -// into an Endpoint. -// -// The path of the url is reset to /concat. -func MakeConcatEndpointFactory(tracer opentracing.Tracer, tracingLogger log.Logger) loadbalancer.Factory { - return func(instance string) (endpoint.Endpoint, io.Closer, error) { - concatURL, err := url.Parse(instance) - if err != nil { - return nil, nil, err - } - concatURL.Path = "/concat" - - client := httptransport.NewClient( - "GET", - concatURL, - server.EncodeConcatRequest, - server.DecodeConcatResponse, - httptransport.SetClient(nil), - httptransport.SetClientBefore(kitot.ToHTTPRequest(tracer, tracingLogger)), - ) - - return client.Endpoint(), nil, nil - } -} diff --git a/examples/addsvc/client/main.go b/examples/addsvc/client/main.go deleted file mode 100644 index 8e0140978..000000000 --- a/examples/addsvc/client/main.go +++ /dev/null @@ -1,167 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/lightstep/lightstep-tracer-go" - "github.com/opentracing/opentracing-go" - zipkin "github.com/openzipkin/zipkin-go-opentracing" - appdashot "github.com/sourcegraph/appdash/opentracing" - "golang.org/x/net/context" - "sourcegraph.com/sourcegraph/appdash" - - "github.com/go-kit/kit/endpoint" - grpcclient "github.com/go-kit/kit/examples/addsvc/client/grpc" - httpjsonclient "github.com/go-kit/kit/examples/addsvc/client/httpjson" - netrpcclient "github.com/go-kit/kit/examples/addsvc/client/netrpc" - thriftclient "github.com/go-kit/kit/examples/addsvc/client/thrift" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/static" - "github.com/go-kit/kit/log" - kitot "github.com/go-kit/kit/tracing/opentracing" -) - -func main() { - var ( - transport = flag.String("transport", "httpjson", "httpjson, grpc, netrpc, thrift") - httpAddrs = flag.String("http.addrs", "localhost:8001", "Comma-separated list of addresses for HTTP (JSON) servers") - grpcAddrs = flag.String("grpc.addrs", "localhost:8002", "Comma-separated list of addresses for gRPC servers") - netrpcAddrs = flag.String("netrpc.addrs", "localhost:8003", "Comma-separated list of addresses for net/rpc servers") - thriftAddrs = flag.String("thrift.addrs", "localhost:8004", "Comma-separated list of addresses for Thrift servers") - thriftProtocol = flag.String("thrift.protocol", "binary", "binary, compact, json, simplejson") - thriftBufferSize = flag.Int("thrift.buffer.size", 0, "0 for unbuffered") - thriftFramed = flag.Bool("thrift.framed", false, "true to enable framing") - - // Three OpenTracing backends (to demonstrate how they can be interchanged): - zipkinAddr = flag.String("zipkin.kafka.addr", "", "Enable Zipkin tracing via a Kafka Collector host:port") - appdashAddr = flag.String("appdash.addr", "", "Enable Appdash tracing via an Appdash server host:port") - lightstepAccessToken = flag.String("lightstep.token", "", "Enable LightStep tracing via a LightStep access token") - ) - flag.Parse() - if len(os.Args) < 4 { - fmt.Fprintf(os.Stderr, "\n%s [flags] method arg1 arg2\n\n", filepath.Base(os.Args[0])) - flag.Usage() - os.Exit(1) - } - - randomSeed := time.Now().UnixNano() - - root := context.Background() - method, s1, s2 := flag.Arg(0), flag.Arg(1), flag.Arg(2) - - var logger log.Logger - logger = log.NewLogfmtLogger(os.Stdout) - logger = log.NewContext(logger).With("caller", log.DefaultCaller) - logger = log.NewContext(logger).With("transport", *transport) - tracingLogger := log.NewContext(logger).With("component", "tracing") - - // Set up OpenTracing - var tracer opentracing.Tracer - { - switch { - case *appdashAddr != "" && *lightstepAccessToken == "" && *zipkinAddr == "": - tracer = appdashot.NewTracer(appdash.NewRemoteCollector(*appdashAddr)) - case *appdashAddr == "" && *lightstepAccessToken != "" && *zipkinAddr == "": - tracer = lightstep.NewTracer(lightstep.Options{ - AccessToken: *lightstepAccessToken, - }) - defer lightstep.FlushLightStepTracer(tracer) - case *appdashAddr == "" && *lightstepAccessToken == "" && *zipkinAddr != "": - collector, err := zipkin.NewKafkaCollector( - strings.Split(*zipkinAddr, ","), - zipkin.KafkaLogger(tracingLogger), - ) - if err != nil { - tracingLogger.Log("err", "unable to create kafka collector", "fatal", err) - os.Exit(1) - } - tracer, err = zipkin.NewTracer( - zipkin.NewRecorder(collector, false, "localhost:8000", "addsvc-client"), - ) - if err != nil { - tracingLogger.Log("err", "unable to create zipkin tracer", "fatal", err) - os.Exit(1) - } - case *appdashAddr == "" && *lightstepAccessToken == "" && *zipkinAddr == "": - tracer = opentracing.GlobalTracer() // no-op - default: - tracingLogger.Log("fatal", "specify a single -appdash.addr, -lightstep.access.token or -zipkin.kafka.addr") - os.Exit(1) - } - } - - var ( - instances []string - sumFactory, concatFactory loadbalancer.Factory - ) - - switch *transport { - case "grpc": - instances = strings.Split(*grpcAddrs, ",") - sumFactory = grpcclient.MakeSumEndpointFactory(tracer, tracingLogger) - concatFactory = grpcclient.MakeConcatEndpointFactory(tracer, tracingLogger) - - case "httpjson": - instances = strings.Split(*httpAddrs, ",") - for i, rawurl := range instances { - if !strings.HasPrefix("http", rawurl) { - instances[i] = "http://" + rawurl - } - } - sumFactory = httpjsonclient.MakeSumEndpointFactory(tracer, tracingLogger) - concatFactory = httpjsonclient.MakeConcatEndpointFactory(tracer, tracingLogger) - - case "netrpc": - instances = strings.Split(*netrpcAddrs, ",") - sumFactory = netrpcclient.SumEndpointFactory - concatFactory = netrpcclient.ConcatEndpointFactory - - case "thrift": - instances = strings.Split(*thriftAddrs, ",") - thriftClient := thriftclient.New(*thriftProtocol, *thriftBufferSize, *thriftFramed, logger) - sumFactory = thriftClient.SumEndpoint - concatFactory = thriftClient.ConcatEndpoint - - default: - logger.Log("err", "invalid transport") - os.Exit(1) - } - - sum := buildEndpoint(tracer, "sum", instances, sumFactory, randomSeed, logger) - concat := buildEndpoint(tracer, "concat", instances, concatFactory, randomSeed, logger) - - svc := newClient(root, sum, concat, logger) - - begin := time.Now() - switch method { - case "sum": - a, _ := strconv.Atoi(s1) - b, _ := strconv.Atoi(s2) - v := svc.Sum(a, b) - logger.Log("method", "sum", "a", a, "b", b, "v", v, "took", time.Since(begin)) - - case "concat": - a, b := s1, s2 - v := svc.Concat(a, b) - logger.Log("method", "concat", "a", a, "b", b, "v", v, "took", time.Since(begin)) - - default: - logger.Log("err", "invalid method "+method) - os.Exit(1) - } - // wait for collector - time.Sleep(2 * time.Second) -} - -func buildEndpoint(tracer opentracing.Tracer, operationName string, instances []string, factory loadbalancer.Factory, seed int64, logger log.Logger) endpoint.Endpoint { - publisher := static.NewPublisher(instances, factory, logger) - random := loadbalancer.NewRandom(publisher, seed) - endpoint := loadbalancer.Retry(10, 10*time.Second, random) - return kitot.TraceClient(tracer, operationName)(endpoint) -} diff --git a/examples/addsvc/client/netrpc/factory.go b/examples/addsvc/client/netrpc/factory.go deleted file mode 100644 index 9f9a87531..000000000 --- a/examples/addsvc/client/netrpc/factory.go +++ /dev/null @@ -1,43 +0,0 @@ -package netrpc - -import ( - "io" - "net/rpc" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/server" -) - -// SumEndpointFactory transforms host:port strings into Endpoints. -func SumEndpointFactory(instance string) (endpoint.Endpoint, io.Closer, error) { - client, err := rpc.DialHTTP("tcp", instance) - if err != nil { - return nil, nil, err - } - - return func(ctx context.Context, request interface{}) (interface{}, error) { - var reply server.SumResponse - if err := client.Call("addsvc.Sum", request.(server.SumRequest), &reply); err != nil { - return server.SumResponse{}, err - } - return reply, nil - }, client, nil -} - -// ConcatEndpointFactory transforms host:port strings into Endpoints. -func ConcatEndpointFactory(instance string) (endpoint.Endpoint, io.Closer, error) { - client, err := rpc.DialHTTP("tcp", instance) - if err != nil { - return nil, nil, err - } - - return func(ctx context.Context, request interface{}) (interface{}, error) { - var reply server.ConcatResponse - if err := client.Call("addsvc.Concat", request.(server.ConcatRequest), &reply); err != nil { - return server.ConcatResponse{}, err - } - return reply, nil - }, client, nil -} diff --git a/examples/addsvc/client/thrift/client.go b/examples/addsvc/client/thrift/client.go index 2234a22ff..a943c7b5c 100644 --- a/examples/addsvc/client/thrift/client.go +++ b/examples/addsvc/client/thrift/client.go @@ -1,97 +1,55 @@ +// Package thrift provides a Thrift client for the add service. package thrift import ( - "io" + "time" - "github.com/apache/thrift/lib/go/thrift" + jujuratelimit "github.com/juju/ratelimit" + "github.com/sony/gobreaker" + + "github.com/go-kit/kit/circuitbreaker" "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/server" - thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/add" - "github.com/go-kit/kit/log" - "golang.org/x/net/context" + "github.com/go-kit/kit/examples/addsvc" + thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/addsvc" + "github.com/go-kit/kit/ratelimit" ) -// New returns a stateful factory for Sum and Concat Endpoints -func New(protocol string, bufferSize int, framed bool, logger log.Logger) client { - var protocolFactory thrift.TProtocolFactory - switch protocol { - case "compact": - protocolFactory = thrift.NewTCompactProtocolFactory() - case "simplejson": - protocolFactory = thrift.NewTSimpleJSONProtocolFactory() - case "json": - protocolFactory = thrift.NewTJSONProtocolFactory() - case "binary", "": - protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() - default: - panic("invalid protocol") - } - - var transportFactory thrift.TTransportFactory - if bufferSize > 0 { - transportFactory = thrift.NewTBufferedTransportFactory(bufferSize) - } else { - transportFactory = thrift.NewTTransportFactory() - } - if framed { - transportFactory = thrift.NewTFramedTransportFactory(transportFactory) +// New returns an AddService backed by a Thrift server described by the provided +// client. The caller is responsible for constructing the client, and eventually +// closing the underlying transport. +func New(client *thriftadd.AddServiceClient) addsvc.Service { + // We construct a single ratelimiter middleware, to limit the total outgoing + // QPS from this client to all methods on the remote instance. We also + // construct per-endpoint circuitbreaker middlewares to demonstrate how + // that's done, although they could easily be combined into a single breaker + // for the entire remote instance, too. + + limiter := ratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(100, 100)) + + // Thrift does not currently have tracer bindings, so we skip tracing. + + var sumEndpoint endpoint.Endpoint + { + sumEndpoint = addsvc.MakeThriftSumEndpoint(client) + sumEndpoint = limiter(sumEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Sum", + Timeout: 30 * time.Second, + }))(sumEndpoint) } - return client{transportFactory, protocolFactory, logger} -} - -type client struct { - thrift.TTransportFactory - thrift.TProtocolFactory - log.Logger -} - -// SumEndpointFactory transforms host:port strings into Endpoints. -func (c client) SumEndpoint(instance string) (endpoint.Endpoint, io.Closer, error) { - transportSocket, err := thrift.NewTSocket(instance) - if err != nil { - c.Logger.Log("during", "thrift.NewTSocket", "err", err) - return nil, nil, err + var concatEndpoint endpoint.Endpoint + { + concatEndpoint = addsvc.MakeThriftConcatEndpoint(client) + concatEndpoint = limiter(concatEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Concat", + Timeout: 30 * time.Second, + }))(sumEndpoint) } - trans := c.TTransportFactory.GetTransport(transportSocket) - if err := trans.Open(); err != nil { - c.Logger.Log("during", "thrift transport.Open", "err", err) - return nil, nil, err + return addsvc.Endpoints{ + SumEndpoint: addsvc.MakeThriftSumEndpoint(client), + ConcatEndpoint: addsvc.MakeThriftConcatEndpoint(client), } - cli := thriftadd.NewAddServiceClientFactory(trans, c.TProtocolFactory) - - return func(ctx context.Context, request interface{}) (interface{}, error) { - sumRequest := request.(server.SumRequest) - reply, err := cli.Sum(int64(sumRequest.A), int64(sumRequest.B)) - if err != nil { - return server.SumResponse{}, err - } - return server.SumResponse{V: int(reply.Value)}, nil - }, trans, nil -} - -// ConcatEndpointFactory transforms host:port strings into Endpoints. -func (c client) ConcatEndpoint(instance string) (endpoint.Endpoint, io.Closer, error) { - transportSocket, err := thrift.NewTSocket(instance) - if err != nil { - c.Logger.Log("during", "thrift.NewTSocket", "err", err) - return nil, nil, err - } - trans := c.TTransportFactory.GetTransport(transportSocket) - - if err := trans.Open(); err != nil { - c.Logger.Log("during", "thrift transport.Open", "err", err) - return nil, nil, err - } - cli := thriftadd.NewAddServiceClientFactory(trans, c.TProtocolFactory) - - return func(ctx context.Context, request interface{}) (interface{}, error) { - concatRequest := request.(server.ConcatRequest) - reply, err := cli.Concat(concatRequest.A, concatRequest.B) - if err != nil { - return server.ConcatResponse{}, err - } - return server.ConcatResponse{V: reply.Value}, nil - }, trans, nil } diff --git a/examples/addsvc/cmd/addcli/main.go b/examples/addsvc/cmd/addcli/main.go new file mode 100644 index 000000000..870bfa8f1 --- /dev/null +++ b/examples/addsvc/cmd/addcli/main.go @@ -0,0 +1,178 @@ +package main + +import ( + "flag" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/apache/thrift/lib/go/thrift" + "github.com/lightstep/lightstep-tracer-go" + stdopentracing "github.com/opentracing/opentracing-go" + zipkin "github.com/openzipkin/zipkin-go-opentracing" + appdashot "github.com/sourcegraph/appdash/opentracing" + "golang.org/x/net/context" + "google.golang.org/grpc" + "sourcegraph.com/sourcegraph/appdash" + + "github.com/go-kit/kit/examples/addsvc" + grpcclient "github.com/go-kit/kit/examples/addsvc/client/grpc" + httpclient "github.com/go-kit/kit/examples/addsvc/client/http" + thriftclient "github.com/go-kit/kit/examples/addsvc/client/thrift" + thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/addsvc" + "github.com/go-kit/kit/log" +) + +func main() { + // The addcli presumes no service discovery system, and expects users to + // provide the direct address of an addsvc. This presumption is reflected in + // the addcli binary and the the client packages: the -transport.addr flags + // and various client constructors both expect host:port strings. For an + // example service with a client built on top of a service discovery system, + // see profilesvc. + + var ( + httpAddr = flag.String("http.addr", "", "HTTP address of addsvc") + grpcAddr = flag.String("grpc.addr", "", "gRPC (HTTP) address of addsvc") + thriftAddr = flag.String("thrift.addr", "", "Thrift address of addsvc") + thriftProtocol = flag.String("thrift.protocol", "binary", "binary, compact, json, simplejson") + thriftBufferSize = flag.Int("thrift.buffer.size", 0, "0 for unbuffered") + thriftFramed = flag.Bool("thrift.framed", false, "true to enable framing") + zipkinAddr = flag.String("zipkin.addr", "", "Enable Zipkin tracing via a Kafka Collector host:port") + appdashAddr = flag.String("appdash.addr", "", "Enable Appdash tracing via an Appdash server host:port") + lightstepToken = flag.String("lightstep.token", "", "Enable LightStep tracing via a LightStep access token") + method = flag.String("method", "sum", "sum, concat") + ) + flag.Parse() + + if len(flag.Args()) != 2 { + fmt.Fprintf(os.Stderr, "usage: addcli [flags] \n") + os.Exit(1) + } + + // This is a demonstration client, which supports multiple tracers. + // Your clients will probably just use one tracer. + var tracer stdopentracing.Tracer + { + if *zipkinAddr != "" { + collector, err := zipkin.NewKafkaCollector( + strings.Split(*zipkinAddr, ","), + zipkin.KafkaLogger(log.NewNopLogger()), + ) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + tracer, err = zipkin.NewTracer( + zipkin.NewRecorder(collector, false, "localhost:8000", "addcli"), + ) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + } else if *appdashAddr != "" { + tracer = appdashot.NewTracer(appdash.NewRemoteCollector(*appdashAddr)) + } else if *lightstepToken != "" { + tracer = lightstep.NewTracer(lightstep.Options{ + AccessToken: *lightstepToken, + }) + defer lightstep.FlushLightStepTracer(tracer) + } else { + tracer = stdopentracing.GlobalTracer() // no-op + } + } + + // This is a demonstration client, which supports multiple transports. + // Your clients will probably just define and stick with 1 transport. + + var ( + service addsvc.Service + err error + ) + if *httpAddr != "" { + service, err = httpclient.New(*httpAddr, tracer, log.NewNopLogger()) + } else if *grpcAddr != "" { + conn, err := grpc.Dial(*grpcAddr, grpc.WithInsecure(), grpc.WithTimeout(time.Second)) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v", err) + os.Exit(1) + } + defer conn.Close() + service = grpcclient.New(conn, tracer, log.NewNopLogger()) + } else if *thriftAddr != "" { + // It's necessary to do all of this construction in the func main, + // because (among other reasons) we need to control the lifecycle of the + // Thrift transport, i.e. close it eventually. + var protocolFactory thrift.TProtocolFactory + switch *thriftProtocol { + case "compact": + protocolFactory = thrift.NewTCompactProtocolFactory() + case "simplejson": + protocolFactory = thrift.NewTSimpleJSONProtocolFactory() + case "json": + protocolFactory = thrift.NewTJSONProtocolFactory() + case "binary", "": + protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() + default: + fmt.Fprintf(os.Stderr, "error: invalid protocol %q\n", *thriftProtocol) + os.Exit(1) + } + var transportFactory thrift.TTransportFactory + if *thriftBufferSize > 0 { + transportFactory = thrift.NewTBufferedTransportFactory(*thriftBufferSize) + } else { + transportFactory = thrift.NewTTransportFactory() + } + if *thriftFramed { + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + } + transportSocket, err := thrift.NewTSocket(*thriftAddr) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + transport := transportFactory.GetTransport(transportSocket) + if err := transport.Open(); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + defer transport.Close() + client := thriftadd.NewAddServiceClientFactory(transport, protocolFactory) + service = thriftclient.New(client) + } else { + fmt.Fprintf(os.Stderr, "error: no remote address specified\n") + os.Exit(1) + } + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + switch *method { + case "sum": + a, _ := strconv.ParseInt(flag.Args()[0], 10, 64) + b, _ := strconv.ParseInt(flag.Args()[1], 10, 64) + v, err := service.Sum(context.Background(), int(a), int(b)) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Fprintf(os.Stdout, "%d + %d = %d\n", a, b, v) + + case "concat": + a := flag.Args()[0] + b := flag.Args()[1] + v, err := service.Concat(context.Background(), a, b) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Fprintf(os.Stdout, "%q + %q = %q\n", a, b, v) + + default: + fmt.Fprintf(os.Stderr, "error: invalid method %q\n", method) + os.Exit(1) + } +} diff --git a/examples/addsvc/cmd/addsvc/main.go b/examples/addsvc/cmd/addsvc/main.go new file mode 100644 index 000000000..2273a4c75 --- /dev/null +++ b/examples/addsvc/cmd/addsvc/main.go @@ -0,0 +1,257 @@ +package main + +import ( + "flag" + "fmt" + "net" + "net/http" + "net/http/pprof" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/apache/thrift/lib/go/thrift" + lightstep "github.com/lightstep/lightstep-tracer-go" + stdopentracing "github.com/opentracing/opentracing-go" + zipkin "github.com/openzipkin/zipkin-go-opentracing" + stdprometheus "github.com/prometheus/client_golang/prometheus" + appdashot "github.com/sourcegraph/appdash/opentracing" + "golang.org/x/net/context" + "google.golang.org/grpc" + "sourcegraph.com/sourcegraph/appdash" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/examples/addsvc" + "github.com/go-kit/kit/examples/addsvc/pb" + thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/addsvc" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/metrics" + "github.com/go-kit/kit/metrics/prometheus" + "github.com/go-kit/kit/tracing/opentracing" +) + +func main() { + var ( + debugAddr = flag.String("debug.addr", ":8080", "Debug and metrics listen address") + httpAddr = flag.String("http.addr", ":8081", "HTTP listen address") + grpcAddr = flag.String("grpc.addr", ":8082", "gRPC (HTTP) listen address") + thriftAddr = flag.String("thrift.addr", ":8083", "Thrift listen address") + thriftProtocol = flag.String("thrift.protocol", "binary", "binary, compact, json, simplejson") + thriftBufferSize = flag.Int("thrift.buffer.size", 0, "0 for unbuffered") + thriftFramed = flag.Bool("thrift.framed", false, "true to enable framing") + zipkinAddr = flag.String("zipkin.addr", "", "Enable Zipkin tracing via a Kafka server host:port") + appdashAddr = flag.String("appdash.addr", "", "Enable Appdash tracing via an Appdash server host:port") + lightstepToken = flag.String("lightstep.token", "", "Enable LightStep tracing via a LightStep access token") + ) + flag.Parse() + + // Logging domain. + var logger log.Logger + { + logger = log.NewLogfmtLogger(os.Stdout) + logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC) + logger = log.NewContext(logger).With("caller", log.DefaultCaller) + } + logger.Log("msg", "hello") + defer logger.Log("msg", "goodbye") + + // Metrics domain. + var ints, chars metrics.Counter + { + // Business level metrics. + ints = prometheus.NewCounter(stdprometheus.CounterOpts{ + Namespace: "addsvc", + Name: "integers_summed", + Help: "Total count of integers summed via the Sum method.", + }, []string{}) + chars = prometheus.NewCounter(stdprometheus.CounterOpts{ + Namespace: "addsvc", + Name: "characters_concatenated", + Help: "Total count of characters concatenated via the Concat method.", + }, []string{}) + } + var duration metrics.TimeHistogram + { + // Transport level metrics. + duration = metrics.NewTimeHistogram(time.Nanosecond, prometheus.NewSummary(stdprometheus.SummaryOpts{ + Namespace: "addsvc", + Name: "request_duration_ns", + Help: "Request duration in nanoseconds.", + }, []string{"method", "success"})) + } + + // Tracing domain. + var tracer stdopentracing.Tracer + { + if *zipkinAddr != "" { + logger := log.NewContext(logger).With("tracer", "Zipkin") + logger.Log("addr", *zipkinAddr) + collector, err := zipkin.NewKafkaCollector( + strings.Split(*zipkinAddr, ","), + zipkin.KafkaLogger(logger), + ) + if err != nil { + logger.Log("err", err) + os.Exit(1) + } + tracer, err = zipkin.NewTracer( + zipkin.NewRecorder(collector, false, "localhost:80", "addsvc"), + ) + if err != nil { + logger.Log("err", err) + os.Exit(1) + } + } else if *appdashAddr != "" { + logger := log.NewContext(logger).With("tracer", "Appdash") + logger.Log("addr", *appdashAddr) + tracer = appdashot.NewTracer(appdash.NewRemoteCollector(*appdashAddr)) + } else if *lightstepToken != "" { + logger := log.NewContext(logger).With("tracer", "LightStep") + logger.Log() // probably don't want to print out the token :) + tracer = lightstep.NewTracer(lightstep.Options{ + AccessToken: *lightstepToken, + }) + defer lightstep.FlushLightStepTracer(tracer) + } else { + logger := log.NewContext(logger).With("tracer", "none") + logger.Log() + tracer = stdopentracing.GlobalTracer() // no-op + } + } + + // Business domain. + var service addsvc.Service + { + service = addsvc.NewBasicService() + service = addsvc.ServiceLoggingMiddleware(logger)(service) + service = addsvc.ServiceInstrumentingMiddleware(ints, chars)(service) + } + + // Endpoint domain. + var sumEndpoint endpoint.Endpoint + { + sumDuration := duration.With(metrics.Field{Key: "method", Value: "Sum"}) + sumLogger := log.NewContext(logger).With("method", "Sum") + + sumEndpoint = addsvc.MakeSumEndpoint(service) + sumEndpoint = opentracing.TraceServer(tracer, "Sum")(sumEndpoint) + sumEndpoint = addsvc.EndpointInstrumentingMiddleware(sumDuration)(sumEndpoint) + sumEndpoint = addsvc.EndpointLoggingMiddleware(sumLogger)(sumEndpoint) + } + var concatEndpoint endpoint.Endpoint + { + concatDuration := duration.With(metrics.Field{Key: "method", Value: "Concat"}) + concatLogger := log.NewContext(logger).With("method", "Concat") + + concatEndpoint = addsvc.MakeConcatEndpoint(service) + concatEndpoint = opentracing.TraceServer(tracer, "Concat")(concatEndpoint) + concatEndpoint = addsvc.EndpointInstrumentingMiddleware(concatDuration)(concatEndpoint) + concatEndpoint = addsvc.EndpointLoggingMiddleware(concatLogger)(concatEndpoint) + } + endpoints := addsvc.Endpoints{ + SumEndpoint: sumEndpoint, + ConcatEndpoint: concatEndpoint, + } + + // Mechanical domain. + errc := make(chan error) + ctx := context.Background() + + // Interrupt handler. + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + errc <- fmt.Errorf("%s", <-c) + }() + + // Debug listener. + go func() { + logger := log.NewContext(logger).With("transport", "debug") + + m := http.NewServeMux() + m.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) + m.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + m.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + m.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + m.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) + m.Handle("/metrics", stdprometheus.Handler()) + + logger.Log("addr", *debugAddr) + errc <- http.ListenAndServe(*debugAddr, m) + }() + + // HTTP transport. + go func() { + logger := log.NewContext(logger).With("transport", "HTTP") + h := addsvc.MakeHTTPHandler(ctx, endpoints, tracer, logger) + logger.Log("addr", *httpAddr) + errc <- http.ListenAndServe(*httpAddr, h) + }() + + // gRPC transport. + go func() { + logger := log.NewContext(logger).With("transport", "gRPC") + + ln, err := net.Listen("tcp", *grpcAddr) + if err != nil { + errc <- err + return + } + + srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger) + s := grpc.NewServer() + pb.RegisterAddServer(s, srv) + + logger.Log("addr", *grpcAddr) + errc <- s.Serve(ln) + }() + + // Thrift transport. + go func() { + logger := log.NewContext(logger).With("transport", "Thrift") + + var protocolFactory thrift.TProtocolFactory + switch *thriftProtocol { + case "binary": + protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() + case "compact": + protocolFactory = thrift.NewTCompactProtocolFactory() + case "json": + protocolFactory = thrift.NewTJSONProtocolFactory() + case "simplejson": + protocolFactory = thrift.NewTSimpleJSONProtocolFactory() + default: + errc <- fmt.Errorf("invalid Thrift protocol %q", *thriftProtocol) + return + } + + var transportFactory thrift.TTransportFactory + if *thriftBufferSize > 0 { + transportFactory = thrift.NewTBufferedTransportFactory(*thriftBufferSize) + } else { + transportFactory = thrift.NewTTransportFactory() + } + if *thriftFramed { + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + } + + transport, err := thrift.NewTServerSocket(*thriftAddr) + if err != nil { + errc <- err + return + } + + logger.Log("addr", *thriftAddr) + errc <- thrift.NewTSimpleServer4( + thriftadd.NewAddServiceProcessor(addsvc.MakeThriftHandler(ctx, endpoints)), + transport, + transportFactory, + protocolFactory, + ).Serve() + }() + + // Run! + logger.Log("exit", <-errc) +} diff --git a/examples/addsvc/doc.go b/examples/addsvc/doc.go new file mode 100644 index 000000000..8865046fb --- /dev/null +++ b/examples/addsvc/doc.go @@ -0,0 +1,6 @@ +// Package addsvc implements the business and transport logic for an example +// service that can sum integers and concatenate strings. +// +// A client library is available in the client subdirectory. A server binary is +// available in cmd/addsrv. An example client binary is available in cmd/addcli. +package addsvc diff --git a/examples/addsvc/endpoint.go b/examples/addsvc/endpoint.go deleted file mode 100644 index 86f58732f..000000000 --- a/examples/addsvc/endpoint.go +++ /dev/null @@ -1,24 +0,0 @@ -package main - -import ( - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/server" -) - -func makeSumEndpoint(svc server.AddService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(*server.SumRequest) - v := svc.Sum(req.A, req.B) - return server.SumResponse{V: v}, nil - } -} - -func makeConcatEndpoint(svc server.AddService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(*server.ConcatRequest) - v := svc.Concat(req.A, req.B) - return server.ConcatResponse{V: v}, nil - } -} diff --git a/examples/addsvc/endpoints.go b/examples/addsvc/endpoints.go new file mode 100644 index 000000000..e46d33b7b --- /dev/null +++ b/examples/addsvc/endpoints.go @@ -0,0 +1,131 @@ +package addsvc + +// This file contains methods to make individual endpoints from services, +// request and response types to serve those endpoints, as well as encoders and +// decoders for those types, for all of our supported transport serialization +// formats. It also includes endpoint middlewares. + +import ( + "fmt" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/metrics" +) + +// Endpoints collects all of the endpoints that compose an add service. It's +// meant to be used as a helper struct, to collect all of the endpoints into a +// single parameter. +// +// In a server, it's useful for functions that need to operate on a per-endpoint +// basis. For example, you might pass an Endpoints to a function that produces +// an http.Handler, with each method (endpoint) wired up to a specific path. (It +// is probably a mistake in design to invoke the Service methods on the +// Endpoints struct in a server.) +// +// In a client, it's useful to collect individually constructed endpoints into a +// single type that implements the Service interface. For example, you might +// construct individual endpoints using transport/http.NewClient, combine them +// into an Endpoints, and return it to the caller as a Service. +type Endpoints struct { + SumEndpoint endpoint.Endpoint + ConcatEndpoint endpoint.Endpoint +} + +// Sum implements Service. Primarily useful in a client. +func (e Endpoints) Sum(ctx context.Context, a, b int) (int, error) { + request := sumRequest{A: a, B: b} + response, err := e.SumEndpoint(ctx, request) + if err != nil { + return 0, err + } + return response.(sumResponse).V, nil +} + +// Concat implements Service. Primarily useful in a client. +func (e Endpoints) Concat(ctx context.Context, a, b string) (string, error) { + request := concatRequest{A: a, B: b} + response, err := e.ConcatEndpoint(ctx, request) + if err != nil { + return "", err + } + return response.(concatResponse).V, err +} + +// MakeSumEndpoint returns an endpoint that invokes Sum on the service. +// Primarily useful in a server. +func MakeSumEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + sumReq := request.(sumRequest) + v, err := s.Sum(ctx, sumReq.A, sumReq.B) + if err != nil { + return nil, err + } + return sumResponse{ + V: v, + }, nil + } +} + +// MakeConcatEndpoint returns an endpoint that invokes Concat on the service. +// Primarily useful in a server. +func MakeConcatEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + concatReq := request.(concatRequest) + v, err := s.Concat(ctx, concatReq.A, concatReq.B) + if err != nil { + return nil, err + } + return concatResponse{ + V: v, + }, nil + } +} + +// EndpointInstrumentingMiddleware returns an endpoint middleware that records +// the duration of each invocation to the passed histogram. The middleware adds +// a single field: "success", which is "true" if no error is returned, and +// "false" otherwise. +func EndpointInstrumentingMiddleware(duration metrics.TimeHistogram) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + + defer func(begin time.Time) { + f := metrics.Field{Key: "success", Value: fmt.Sprint(err == nil)} + duration.With(f).Observe(time.Since(begin)) + }(time.Now()) + return next(ctx, request) + + } + } +} + +// EndpointLoggingMiddleware returns an endpoint middleware that logs the +// duration of each invocation, and the resulting error, if any. +func EndpointLoggingMiddleware(logger log.Logger) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + + defer func(begin time.Time) { + logger.Log("error", err, "took", time.Since(begin)) + }(time.Now()) + return next(ctx, request) + + } + } +} + +// These types are unexported because they only exist to serve the endpoint +// domain, which is totally encapsulated in this package. They are otherwise +// opaque to all callers. + +type sumRequest struct{ A, B int } + +type sumResponse struct{ V int } + +type concatRequest struct{ A, B string } + +type concatResponse struct{ V string } diff --git a/examples/addsvc/grpc_binding.go b/examples/addsvc/grpc_binding.go deleted file mode 100644 index 101c71f2c..000000000 --- a/examples/addsvc/grpc_binding.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "golang.org/x/net/context" - - "github.com/opentracing/opentracing-go" - - "github.com/go-kit/kit/examples/addsvc/pb" - "github.com/go-kit/kit/examples/addsvc/server" - servergrpc "github.com/go-kit/kit/examples/addsvc/server/grpc" - "github.com/go-kit/kit/log" - kitot "github.com/go-kit/kit/tracing/opentracing" - "github.com/go-kit/kit/transport/grpc" -) - -type grpcBinding struct { - sum, concat grpc.Handler -} - -func newGRPCBinding(ctx context.Context, tracer opentracing.Tracer, svc server.AddService, tracingLogger log.Logger) grpcBinding { - return grpcBinding{ - sum: grpc.NewServer( - ctx, - kitot.TraceServer(tracer, "sum")(makeSumEndpoint(svc)), - servergrpc.DecodeSumRequest, - servergrpc.EncodeSumResponse, - grpc.ServerBefore(kitot.FromGRPCRequest(tracer, "", tracingLogger)), - ), - concat: grpc.NewServer( - ctx, - kitot.TraceServer(tracer, "concat")(makeConcatEndpoint(svc)), - servergrpc.DecodeConcatRequest, - servergrpc.EncodeConcatResponse, - grpc.ServerBefore(kitot.FromGRPCRequest(tracer, "", tracingLogger)), - ), - } -} - -func (b grpcBinding) Sum(ctx context.Context, req *pb.SumRequest) (*pb.SumReply, error) { - _, resp, err := b.sum.ServeGRPC(ctx, req) - return resp.(*pb.SumReply), err -} - -func (b grpcBinding) Concat(ctx context.Context, req *pb.ConcatRequest) (*pb.ConcatReply, error) { - _, resp, err := b.concat.ServeGRPC(ctx, req) - return resp.(*pb.ConcatReply), err -} diff --git a/examples/addsvc/main.go b/examples/addsvc/main.go deleted file mode 100644 index 17443cfdb..000000000 --- a/examples/addsvc/main.go +++ /dev/null @@ -1,257 +0,0 @@ -package main - -import ( - "flag" - "fmt" - stdlog "log" - "math/rand" - "net" - "net/http" - "net/rpc" - "os" - "os/signal" - "strings" - "syscall" - "time" - - "github.com/apache/thrift/lib/go/thrift" - "github.com/lightstep/lightstep-tracer-go" - "github.com/opentracing/opentracing-go" - zipkin "github.com/openzipkin/zipkin-go-opentracing" - stdprometheus "github.com/prometheus/client_golang/prometheus" - appdashot "github.com/sourcegraph/appdash/opentracing" - "golang.org/x/net/context" - "google.golang.org/grpc" - "sourcegraph.com/sourcegraph/appdash" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/pb" - "github.com/go-kit/kit/examples/addsvc/server" - thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/add" - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/metrics" - "github.com/go-kit/kit/metrics/expvar" - "github.com/go-kit/kit/metrics/prometheus" - kitot "github.com/go-kit/kit/tracing/opentracing" - httptransport "github.com/go-kit/kit/transport/http" -) - -func main() { - // Flag domain. Note that gRPC transitively registers flags via its import - // of glog. So, we define a new flag set, to keep those domains distinct. - fs := flag.NewFlagSet("", flag.ExitOnError) - var ( - debugAddr = fs.String("debug.addr", ":8000", "Address for HTTP debug/instrumentation server") - httpAddr = fs.String("http.addr", ":8001", "Address for HTTP (JSON) server") - grpcAddr = fs.String("grpc.addr", ":8002", "Address for gRPC server") - netrpcAddr = fs.String("netrpc.addr", ":8003", "Address for net/rpc server") - thriftAddr = fs.String("thrift.addr", ":8004", "Address for Thrift server") - thriftProtocol = fs.String("thrift.protocol", "binary", "binary, compact, json, simplejson") - thriftBufferSize = fs.Int("thrift.buffer.size", 0, "0 for unbuffered") - thriftFramed = fs.Bool("thrift.framed", false, "true to enable framing") - - // Supported OpenTracing backends - zipkinAddr = fs.String("zipkin.kafka.addr", "", "Enable Zipkin tracing via a Kafka server host:port") - appdashAddr = fs.String("appdash.addr", "", "Enable Appdash tracing via an Appdash server host:port") - lightstepAccessToken = fs.String("lightstep.token", "", "Enable LightStep tracing via a LightStep access token") - ) - flag.Usage = fs.Usage // only show our flags - if err := fs.Parse(os.Args[1:]); err != nil { - fmt.Fprintf(os.Stderr, "%v", err) - os.Exit(1) - } - - // package log - var logger log.Logger - { - logger = log.NewLogfmtLogger(os.Stderr) - logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC).With("caller", log.DefaultCaller) - stdlog.SetFlags(0) // flags are handled by Go kit's logger - stdlog.SetOutput(log.NewStdlibAdapter(logger)) // redirect anything using stdlib log to us - } - - // package metrics - var requestDuration metrics.TimeHistogram - { - requestDuration = metrics.NewTimeHistogram(time.Nanosecond, metrics.NewMultiHistogram( - "request_duration_ns", - expvar.NewHistogram("request_duration_ns", 0, 5e9, 1, 50, 95, 99), - prometheus.NewSummary(stdprometheus.SummaryOpts{ - Namespace: "myorg", - Subsystem: "addsvc", - Name: "duration_ns", - Help: "Request duration in nanoseconds.", - }, []string{"method"}), - )) - } - - // Set up OpenTracing - var tracer opentracing.Tracer - { - switch { - case *appdashAddr != "" && *lightstepAccessToken == "" && *zipkinAddr == "": - tracer = appdashot.NewTracer(appdash.NewRemoteCollector(*appdashAddr)) - case *appdashAddr == "" && *lightstepAccessToken != "" && *zipkinAddr == "": - tracer = lightstep.NewTracer(lightstep.Options{ - AccessToken: *lightstepAccessToken, - }) - defer lightstep.FlushLightStepTracer(tracer) - case *appdashAddr == "" && *lightstepAccessToken == "" && *zipkinAddr != "": - collector, err := zipkin.NewKafkaCollector( - strings.Split(*zipkinAddr, ","), - zipkin.KafkaLogger(logger), - ) - if err != nil { - logger.Log("err", "unable to create collector", "fatal", err) - os.Exit(1) - } - tracer, err = zipkin.NewTracer( - zipkin.NewRecorder(collector, false, "localhost:80", "addsvc"), - ) - if err != nil { - logger.Log("err", "unable to create zipkin tracer", "fatal", err) - os.Exit(1) - } - case *appdashAddr == "" && *lightstepAccessToken == "" && *zipkinAddr == "": - tracer = opentracing.GlobalTracer() // no-op - default: - logger.Log("fatal", "specify a single -appdash.addr, -lightstep.access.token or -zipkin.kafka.addr") - os.Exit(1) - } - } - - // Business domain - var svc server.AddService - { - svc = pureAddService{} - svc = loggingMiddleware{svc, logger} - svc = instrumentingMiddleware{svc, requestDuration} - } - - // Mechanical stuff - rand.Seed(time.Now().UnixNano()) - root := context.Background() - errc := make(chan error) - - go func() { - errc <- interrupt() - }() - - // Debug/instrumentation - go func() { - transportLogger := log.NewContext(logger).With("transport", "debug") - transportLogger.Log("addr", *debugAddr) - errc <- http.ListenAndServe(*debugAddr, nil) // DefaultServeMux - }() - - // Transport: HTTP/JSON - go func() { - var ( - transportLogger = log.NewContext(logger).With("transport", "HTTP/JSON") - tracingLogger = log.NewContext(transportLogger).With("component", "tracing") - mux = http.NewServeMux() - sum, concat endpoint.Endpoint - ) - - sum = makeSumEndpoint(svc) - sum = kitot.TraceServer(tracer, "sum")(sum) - mux.Handle("/sum", httptransport.NewServer( - root, - sum, - server.DecodeSumRequest, - server.EncodeSumResponse, - httptransport.ServerErrorLogger(transportLogger), - httptransport.ServerBefore(kitot.FromHTTPRequest(tracer, "sum", tracingLogger)), - )) - - concat = makeConcatEndpoint(svc) - concat = kitot.TraceServer(tracer, "concat")(concat) - mux.Handle("/concat", httptransport.NewServer( - root, - concat, - server.DecodeConcatRequest, - server.EncodeConcatResponse, - httptransport.ServerErrorLogger(transportLogger), - httptransport.ServerBefore(kitot.FromHTTPRequest(tracer, "concat", tracingLogger)), - )) - - transportLogger.Log("addr", *httpAddr) - errc <- http.ListenAndServe(*httpAddr, mux) - }() - - // Transport: gRPC - go func() { - transportLogger := log.NewContext(logger).With("transport", "gRPC") - tracingLogger := log.NewContext(transportLogger).With("component", "tracing") - ln, err := net.Listen("tcp", *grpcAddr) - if err != nil { - errc <- err - return - } - s := grpc.NewServer() // uses its own, internal context - pb.RegisterAddServer(s, newGRPCBinding(root, tracer, svc, tracingLogger)) - transportLogger.Log("addr", *grpcAddr) - errc <- s.Serve(ln) - }() - - // Transport: net/rpc - go func() { - transportLogger := log.NewContext(logger).With("transport", "net/rpc") - s := rpc.NewServer() - if err := s.RegisterName("addsvc", netrpcBinding{svc}); err != nil { - errc <- err - return - } - s.HandleHTTP(rpc.DefaultRPCPath, rpc.DefaultDebugPath) - transportLogger.Log("addr", *netrpcAddr) - errc <- http.ListenAndServe(*netrpcAddr, s) - }() - - // Transport: Thrift - go func() { - var protocolFactory thrift.TProtocolFactory - switch *thriftProtocol { - case "binary": - protocolFactory = thrift.NewTBinaryProtocolFactoryDefault() - case "compact": - protocolFactory = thrift.NewTCompactProtocolFactory() - case "json": - protocolFactory = thrift.NewTJSONProtocolFactory() - case "simplejson": - protocolFactory = thrift.NewTSimpleJSONProtocolFactory() - default: - errc <- fmt.Errorf("invalid Thrift protocol %q", *thriftProtocol) - return - } - var transportFactory thrift.TTransportFactory - if *thriftBufferSize > 0 { - transportFactory = thrift.NewTBufferedTransportFactory(*thriftBufferSize) - } else { - transportFactory = thrift.NewTTransportFactory() - } - if *thriftFramed { - transportFactory = thrift.NewTFramedTransportFactory(transportFactory) - } - transport, err := thrift.NewTServerSocket(*thriftAddr) - if err != nil { - errc <- err - return - } - transportLogger := log.NewContext(logger).With("transport", "thrift") - transportLogger.Log("addr", *thriftAddr) - errc <- thrift.NewTSimpleServer4( - thriftadd.NewAddServiceProcessor(thriftBinding{svc}), - transport, - transportFactory, - protocolFactory, - ).Serve() - }() - - logger.Log("fatal", <-errc) -} - -func interrupt() error { - c := make(chan os.Signal) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) - return fmt.Errorf("%s", <-c) -} diff --git a/examples/addsvc/netrpc_binding.go b/examples/addsvc/netrpc_binding.go deleted file mode 100644 index 6171a8aaa..000000000 --- a/examples/addsvc/netrpc_binding.go +++ /dev/null @@ -1,21 +0,0 @@ -package main - -import ( - "github.com/go-kit/kit/examples/addsvc/server" -) - -type netrpcBinding struct { - server.AddService -} - -func (b netrpcBinding) Sum(request server.SumRequest, response *server.SumResponse) error { - v := b.AddService.Sum(request.A, request.B) - (*response) = server.SumResponse{V: v} - return nil -} - -func (b netrpcBinding) Concat(request server.ConcatRequest, response *server.ConcatResponse) error { - v := b.AddService.Concat(request.A, request.B) - (*response) = server.ConcatResponse{V: v} - return nil -} diff --git a/examples/addsvc/pb/add.pb.go b/examples/addsvc/pb/addsvc.pb.go similarity index 84% rename from examples/addsvc/pb/add.pb.go rename to examples/addsvc/pb/addsvc.pb.go index 1ac0c8e58..c766982e7 100644 --- a/examples/addsvc/pb/add.pb.go +++ b/examples/addsvc/pb/addsvc.pb.go @@ -1,12 +1,12 @@ // Code generated by protoc-gen-go. -// source: add.proto +// source: addsvc.proto // DO NOT EDIT! /* Package pb is a generated protocol buffer package. It is generated from these files: - add.proto + addsvc.proto It has these top-level messages: SumRequest @@ -192,16 +192,16 @@ var _Add_serviceDesc = grpc.ServiceDesc{ } var fileDescriptor0 = []byte{ - // 171 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0x4c, 0x49, 0xd1, - 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe0, 0xe2, 0x0a, 0x2e, 0xcd, - 0x0d, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60, 0x54, 0x60, - 0xd4, 0x60, 0x0e, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x20, 0xbc, 0x24, 0x25, 0x09, 0x2e, - 0x0e, 0xb0, 0xca, 0x82, 0x9c, 0x4a, 0x90, 0x4c, 0x19, 0x4c, 0x5d, 0x99, 0x92, 0x36, 0x17, 0xaf, - 0x73, 0x7e, 0x5e, 0x72, 0x62, 0x09, 0x86, 0x31, 0x9c, 0x28, 0xc6, 0x70, 0x82, 0x8c, 0x91, 0xe6, - 0xe2, 0x86, 0x29, 0x46, 0x31, 0x09, 0x28, 0x59, 0x66, 0x14, 0xc3, 0xc5, 0xec, 0x98, 0x92, 0x22, - 0xa4, 0xca, 0xc5, 0x0c, 0xb4, 0x4a, 0x88, 0x4f, 0xaf, 0x20, 0x49, 0x0f, 0xe1, 0x3a, 0x29, 0x1e, - 0x38, 0x1f, 0xa8, 0x53, 0x89, 0x41, 0x48, 0x8f, 0x8b, 0x0d, 0x62, 0x94, 0x90, 0x20, 0x48, 0x06, - 0xc5, 0x0d, 0x52, 0xfc, 0xc8, 0x42, 0x60, 0xf5, 0x49, 0x6c, 0x60, 0x6f, 0x1b, 0x03, 0x02, 0x00, - 0x00, 0xff, 0xff, 0xb4, 0xc9, 0xe7, 0x58, 0x03, 0x01, 0x00, 0x00, + // 174 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0x4c, 0x49, 0x29, + 0x2e, 0x4b, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe0, 0xe2, + 0x0a, 0x2e, 0xcd, 0x0d, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, + 0x60, 0x54, 0x60, 0xd4, 0x60, 0x0e, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x20, 0xbc, 0x24, + 0x25, 0x09, 0x2e, 0x0e, 0xb0, 0xca, 0x82, 0x9c, 0x4a, 0x90, 0x4c, 0x19, 0x4c, 0x5d, 0x99, 0x92, + 0x36, 0x17, 0xaf, 0x73, 0x7e, 0x5e, 0x72, 0x62, 0x09, 0x86, 0x31, 0x9c, 0x28, 0xc6, 0x70, 0x82, + 0x8c, 0x91, 0xe6, 0xe2, 0x86, 0x29, 0x46, 0x31, 0x09, 0x28, 0x59, 0x66, 0x14, 0xc3, 0xc5, 0xec, + 0x98, 0x92, 0x22, 0xa4, 0xca, 0xc5, 0x0c, 0xb4, 0x4a, 0x88, 0x4f, 0xaf, 0x20, 0x49, 0x0f, 0xe1, + 0x3a, 0x29, 0x1e, 0x38, 0x1f, 0xa8, 0x53, 0x89, 0x41, 0x48, 0x8f, 0x8b, 0x0d, 0x62, 0x94, 0x90, + 0x20, 0x48, 0x06, 0xc5, 0x0d, 0x52, 0xfc, 0xc8, 0x42, 0x60, 0xf5, 0x49, 0x6c, 0x60, 0x6f, 0x1b, + 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x8b, 0x2c, 0x12, 0xb4, 0x06, 0x01, 0x00, 0x00, } diff --git a/examples/addsvc/pb/add.proto b/examples/addsvc/pb/addsvc.proto similarity index 100% rename from examples/addsvc/pb/add.proto rename to examples/addsvc/pb/addsvc.proto diff --git a/examples/addsvc/pb/compile.sh b/examples/addsvc/pb/compile.sh index e23129229..c0268442a 100755 --- a/examples/addsvc/pb/compile.sh +++ b/examples/addsvc/pb/compile.sh @@ -11,4 +11,4 @@ # See also # https://github.com/grpc/grpc-go/tree/master/examples -protoc add.proto --go_out=plugins=grpc:. +protoc addsvc.proto --go_out=plugins=grpc:. diff --git a/examples/addsvc/server/encode_decode.go b/examples/addsvc/server/encode_decode.go deleted file mode 100644 index bf3e80346..000000000 --- a/examples/addsvc/server/encode_decode.go +++ /dev/null @@ -1,84 +0,0 @@ -package server - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - - "golang.org/x/net/context" -) - -// DecodeSumRequest decodes the request from the provided HTTP request, simply -// by JSON decoding from the request body. It's designed to be used in -// transport/http.Server. -func DecodeSumRequest(_ context.Context, r *http.Request) (interface{}, error) { - var request SumRequest - err := json.NewDecoder(r.Body).Decode(&request) - return &request, err -} - -// EncodeSumResponse encodes the response to the provided HTTP response -// writer, simply by JSON encoding to the writer. It's designed to be used in -// transport/http.Server. -func EncodeSumResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { - return json.NewEncoder(w).Encode(response) -} - -// DecodeConcatRequest decodes the request from the provided HTTP request, -// simply by JSON decoding from the request body. It's designed to be used in -// transport/http.Server. -func DecodeConcatRequest(_ context.Context, r *http.Request) (interface{}, error) { - var request ConcatRequest - err := json.NewDecoder(r.Body).Decode(&request) - return &request, err -} - -// EncodeConcatResponse encodes the response to the provided HTTP response -// writer, simply by JSON encoding to the writer. It's designed to be used in -// transport/http.Server. -func EncodeConcatResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { - return json.NewEncoder(w).Encode(response) -} - -// EncodeSumRequest encodes the request to the provided HTTP request, simply -// by JSON encoding to the request body. It's designed to be used in -// transport/http.Client. -func EncodeSumRequest(_ context.Context, r *http.Request, request interface{}) error { - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(request); err != nil { - return err - } - r.Body = ioutil.NopCloser(&buf) - return nil -} - -// DecodeSumResponse decodes the response from the provided HTTP response, -// simply by JSON decoding from the response body. It's designed to be used in -// transport/http.Client. -func DecodeSumResponse(_ context.Context, resp *http.Response) (interface{}, error) { - var response SumResponse - err := json.NewDecoder(resp.Body).Decode(&response) - return response, err -} - -// EncodeConcatRequest encodes the request to the provided HTTP request, -// simply by JSON encoding to the request body. It's designed to be used in -// transport/http.Client. -func EncodeConcatRequest(_ context.Context, r *http.Request, request interface{}) error { - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(request); err != nil { - return err - } - r.Body = ioutil.NopCloser(&buf) - return nil -} - -// DecodeConcatResponse decodes the response from the provided HTTP response, -// simply by JSON decoding from the response body. It's designed to be used in -// transport/http.Client. -func DecodeConcatResponse(_ context.Context, resp *http.Response) (interface{}, error) { - var response ConcatResponse - err := json.NewDecoder(resp.Body).Decode(&response) - return response, err -} diff --git a/examples/addsvc/server/grpc/encode_decode.go b/examples/addsvc/server/grpc/encode_decode.go deleted file mode 100644 index 05a763490..000000000 --- a/examples/addsvc/server/grpc/encode_decode.go +++ /dev/null @@ -1,42 +0,0 @@ -package grpc - -import ( - "golang.org/x/net/context" - - "github.com/go-kit/kit/examples/addsvc/pb" - "github.com/go-kit/kit/examples/addsvc/server" -) - -func DecodeSumRequest(ctx context.Context, req interface{}) (interface{}, error) { - sumRequest := req.(*pb.SumRequest) - - return &server.SumRequest{ - A: int(sumRequest.A), - B: int(sumRequest.B), - }, nil -} - -func DecodeConcatRequest(ctx context.Context, req interface{}) (interface{}, error) { - concatRequest := req.(*pb.ConcatRequest) - - return &server.ConcatRequest{ - A: concatRequest.A, - B: concatRequest.B, - }, nil -} - -func EncodeSumResponse(ctx context.Context, resp interface{}) (interface{}, error) { - domainResponse := resp.(server.SumResponse) - - return &pb.SumReply{ - V: int64(domainResponse.V), - }, nil -} - -func EncodeConcatResponse(ctx context.Context, resp interface{}) (interface{}, error) { - domainResponse := resp.(server.ConcatResponse) - - return &pb.ConcatReply{ - V: domainResponse.V, - }, nil -} diff --git a/examples/addsvc/server/request_response.go b/examples/addsvc/server/request_response.go deleted file mode 100644 index c01b249a3..000000000 --- a/examples/addsvc/server/request_response.go +++ /dev/null @@ -1,23 +0,0 @@ -package server - -// SumRequest is the business domain type for a Sum method request. -type SumRequest struct { - A int `json:"a"` - B int `json:"b"` -} - -// SumResponse is the business domain type for a Sum method response. -type SumResponse struct { - V int `json:"v"` -} - -// ConcatRequest is the business domain type for a Concat method request. -type ConcatRequest struct { - A string `json:"a"` - B string `json:"b"` -} - -// ConcatResponse is the business domain type for a Concat method response. -type ConcatResponse struct { - V string `json:"v"` -} diff --git a/examples/addsvc/server/service.go b/examples/addsvc/server/service.go deleted file mode 100644 index ecc63687d..000000000 --- a/examples/addsvc/server/service.go +++ /dev/null @@ -1,7 +0,0 @@ -package server - -// AddService is the abstract representation of this service. -type AddService interface { - Sum(a, b int) int - Concat(a, b string) string -} diff --git a/examples/addsvc/service.go b/examples/addsvc/service.go index f0531a267..925c08b84 100644 --- a/examples/addsvc/service.go +++ b/examples/addsvc/service.go @@ -1,71 +1,135 @@ -package main +package addsvc + +// This file contains the Service definition, and a basic service +// implementation. It also includes service middlewares. import ( + "errors" "time" - "github.com/go-kit/kit/examples/addsvc/server" + "golang.org/x/net/context" + "github.com/go-kit/kit/log" "github.com/go-kit/kit/metrics" ) -type pureAddService struct{} +// Service describes a service that adds things together. +type Service interface { + Sum(ctx context.Context, a, b int) (int, error) + Concat(ctx context.Context, a, b string) (string, error) +} + +var ( + // ErrTwoZeroes is an arbitrary business rule for the Add method. + ErrTwoZeroes = errors.New("can't sum two zeroes") + + // ErrIntOverflow protects the Add method. + ErrIntOverflow = errors.New("integer overflow") + + // ErrMaxSizeExceeded protects the Concat method. + ErrMaxSizeExceeded = errors.New("result exceeds maximum size") +) + +// NewBasicService returns a naïve, stateless implementation of Service. +func NewBasicService() Service { + return basicService{} +} + +type basicService struct{} -func (pureAddService) Sum(a, b int) int { return a + b } +const ( + intMax = 1<<31 - 1 + intMin = -(intMax + 1) + maxLen = 102400 +) -func (pureAddService) Concat(a, b string) string { return a + b } +// Sum implements Service. +func (s basicService) Sum(_ context.Context, a, b int) (int, error) { + if a == 0 && b == 0 { + return 0, ErrTwoZeroes + } + if (b > 0 && a > (intMax-b)) || (b < 0 && a < (intMin-b)) { + return 0, ErrIntOverflow + } + return a + b, nil +} -type loggingMiddleware struct { - server.AddService - log.Logger +// Concat implements Service. +func (s basicService) Concat(_ context.Context, a, b string) (string, error) { + if len(a)+len(b) > maxLen { + return "", ErrMaxSizeExceeded + } + return a + b, nil } -func (m loggingMiddleware) Sum(a, b int) (v int) { +// Middleware describes a service (as opposed to endpoint) middleware. +type Middleware func(Service) Service + +// ServiceLoggingMiddleware returns a service middleware that logs the +// parameters and result of each method invocation. +func ServiceLoggingMiddleware(logger log.Logger) Middleware { + return func(next Service) Service { + return serviceLoggingMiddleware{ + logger: logger, + next: next, + } + } +} + +type serviceLoggingMiddleware struct { + logger log.Logger + next Service +} + +func (mw serviceLoggingMiddleware) Sum(ctx context.Context, a, b int) (v int, err error) { defer func(begin time.Time) { - m.Logger.Log( - "method", "sum", - "a", a, - "b", b, - "v", v, + mw.logger.Log( + "method", "Sum", + "a", a, "b", b, "result", v, "error", err, "took", time.Since(begin), ) }(time.Now()) - v = m.AddService.Sum(a, b) - return + return mw.next.Sum(ctx, a, b) } -func (m loggingMiddleware) Concat(a, b string) (v string) { +func (mw serviceLoggingMiddleware) Concat(ctx context.Context, a, b string) (v string, err error) { defer func(begin time.Time) { - m.Logger.Log( - "method", "concat", - "a", a, - "b", b, - "v", v, + mw.logger.Log( + "method", "Concat", + "a", a, "b", b, "result", v, "error", err, "took", time.Since(begin), ) }(time.Now()) - v = m.AddService.Concat(a, b) - return + return mw.next.Concat(ctx, a, b) } -type instrumentingMiddleware struct { - server.AddService - requestDuration metrics.TimeHistogram +// ServiceInstrumentingMiddleware returns a service middleware that instruments +// the number of integers summed and characters concatenated over the lifetime of +// the service. +func ServiceInstrumentingMiddleware(ints, chars metrics.Counter) Middleware { + return func(next Service) Service { + return serviceInstrumentingMiddleware{ + ints: ints, + chars: chars, + next: next, + } + } } -func (m instrumentingMiddleware) Sum(a, b int) (v int) { - defer func(begin time.Time) { - methodField := metrics.Field{Key: "method", Value: "sum"} - m.requestDuration.With(methodField).Observe(time.Since(begin)) - }(time.Now()) - v = m.AddService.Sum(a, b) - return +type serviceInstrumentingMiddleware struct { + ints metrics.Counter + chars metrics.Counter + next Service } -func (m instrumentingMiddleware) Concat(a, b string) (v string) { - defer func(begin time.Time) { - methodField := metrics.Field{Key: "method", Value: "concat"} - m.requestDuration.With(methodField).Observe(time.Since(begin)) - }(time.Now()) - v = m.AddService.Concat(a, b) - return +func (mw serviceInstrumentingMiddleware) Sum(ctx context.Context, a, b int) (int, error) { + v, err := mw.next.Sum(ctx, a, b) + mw.ints.Add(uint64(v)) + return v, err +} + +func (mw serviceInstrumentingMiddleware) Concat(ctx context.Context, a, b string) (string, error) { + v, err := mw.next.Concat(ctx, a, b) + mw.chars.Add(uint64(len(v))) + return v, err } diff --git a/examples/addsvc/thrift/add.thrift b/examples/addsvc/thrift/addsvc.thrift similarity index 100% rename from examples/addsvc/thrift/add.thrift rename to examples/addsvc/thrift/addsvc.thrift diff --git a/examples/addsvc/thrift/compile.sh b/examples/addsvc/thrift/compile.sh index 11354b193..2ecce5b29 100755 --- a/examples/addsvc/thrift/compile.sh +++ b/examples/addsvc/thrift/compile.sh @@ -2,4 +2,4 @@ # See also https://thrift.apache.org/tutorial/go -thrift -r --gen "go:package_prefix=github.com/go-kit/kit/examples/addsvc/thrift/gen-go/,thrift_import=github.com/apache/thrift/lib/go/thrift" add.thrift +thrift -r --gen "go:package_prefix=github.com/go-kit/kit/examples/addsvc/thrift/gen-go/,thrift_import=github.com/apache/thrift/lib/go/thrift" addsvc.thrift diff --git a/examples/addsvc/thrift/gen-go/add/add_service-remote/add_service-remote.go b/examples/addsvc/thrift/gen-go/addsvc/add_service-remote/add_service-remote.go similarity index 96% rename from examples/addsvc/thrift/gen-go/add/add_service-remote/add_service-remote.go rename to examples/addsvc/thrift/gen-go/addsvc/add_service-remote/add_service-remote.go index 79606cb2c..b8ce67ca2 100755 --- a/examples/addsvc/thrift/gen-go/add/add_service-remote/add_service-remote.go +++ b/examples/addsvc/thrift/gen-go/addsvc/add_service-remote/add_service-remote.go @@ -7,7 +7,7 @@ import ( "flag" "fmt" "github.com/apache/thrift/lib/go/thrift" - "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/add" + "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/addsvc" "math" "net" "net/url" @@ -109,7 +109,7 @@ func main() { Usage() os.Exit(1) } - client := add.NewAddServiceClientFactory(trans, protocolFactory) + client := addsvc.NewAddServiceClientFactory(trans, protocolFactory) if err := trans.Open(); err != nil { fmt.Fprintln(os.Stderr, "Error opening socket to ", host, ":", port, " ", err) os.Exit(1) diff --git a/examples/addsvc/thrift/gen-go/add/addservice.go b/examples/addsvc/thrift/gen-go/addsvc/addservice.go similarity index 99% rename from examples/addsvc/thrift/gen-go/add/addservice.go rename to examples/addsvc/thrift/gen-go/addsvc/addservice.go index dbeb1191c..3f3aeebf1 100644 --- a/examples/addsvc/thrift/gen-go/add/addservice.go +++ b/examples/addsvc/thrift/gen-go/addsvc/addservice.go @@ -1,7 +1,7 @@ // Autogenerated by Thrift Compiler (0.9.3) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -package add +package addsvc import ( "bytes" diff --git a/examples/addsvc/thrift/gen-go/add/constants.go b/examples/addsvc/thrift/gen-go/addsvc/constants.go similarity index 95% rename from examples/addsvc/thrift/gen-go/add/constants.go rename to examples/addsvc/thrift/gen-go/addsvc/constants.go index 64a325807..2f0079acc 100644 --- a/examples/addsvc/thrift/gen-go/add/constants.go +++ b/examples/addsvc/thrift/gen-go/addsvc/constants.go @@ -1,7 +1,7 @@ // Autogenerated by Thrift Compiler (0.9.3) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -package add +package addsvc import ( "bytes" diff --git a/examples/addsvc/thrift/gen-go/add/ttypes.go b/examples/addsvc/thrift/gen-go/addsvc/ttypes.go similarity index 99% rename from examples/addsvc/thrift/gen-go/add/ttypes.go rename to examples/addsvc/thrift/gen-go/addsvc/ttypes.go index 744abbaf6..bbae46c5c 100644 --- a/examples/addsvc/thrift/gen-go/add/ttypes.go +++ b/examples/addsvc/thrift/gen-go/addsvc/ttypes.go @@ -1,7 +1,7 @@ // Autogenerated by Thrift Compiler (0.9.3) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -package add +package addsvc import ( "bytes" diff --git a/examples/addsvc/thrift_binding.go b/examples/addsvc/thrift_binding.go deleted file mode 100644 index 959fcad29..000000000 --- a/examples/addsvc/thrift_binding.go +++ /dev/null @@ -1,20 +0,0 @@ -package main - -import ( - "github.com/go-kit/kit/examples/addsvc/server" - thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/add" -) - -type thriftBinding struct { - server.AddService -} - -func (tb thriftBinding) Sum(a, b int64) (*thriftadd.SumReply, error) { - v := tb.AddService.Sum(int(a), int(b)) - return &thriftadd.SumReply{Value: int64(v)}, nil -} - -func (tb thriftBinding) Concat(a, b string) (*thriftadd.ConcatReply, error) { - v := tb.AddService.Concat(a, b) - return &thriftadd.ConcatReply{Value: v}, nil -} diff --git a/examples/addsvc/transport_grpc.go b/examples/addsvc/transport_grpc.go new file mode 100644 index 000000000..dc45e4de8 --- /dev/null +++ b/examples/addsvc/transport_grpc.go @@ -0,0 +1,112 @@ +package addsvc + +// This file provides server-side bindings for the gRPC transport. +// It utilizes the transport/grpc.Server. + +import ( + stdopentracing "github.com/opentracing/opentracing-go" + "golang.org/x/net/context" + + "github.com/go-kit/kit/examples/addsvc/pb" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/tracing/opentracing" + grpctransport "github.com/go-kit/kit/transport/grpc" +) + +// MakeGRPCServer makes a set of endpoints available as a gRPC AddServer. +func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { + options := []grpctransport.ServerOption{ + grpctransport.ServerErrorLogger(logger), + } + return &grpcServer{ + sum: grpctransport.NewServer( + ctx, + endpoints.SumEndpoint, + DecodeGRPCSumRequest, + EncodeGRPCSumResponse, + append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))..., + ), + concat: grpctransport.NewServer( + ctx, + endpoints.ConcatEndpoint, + DecodeGRPCConcatRequest, + EncodeGRPCConcatResponse, + append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Concat", logger)))..., + ), + } +} + +type grpcServer struct { + sum grpctransport.Handler + concat grpctransport.Handler +} + +func (s *grpcServer) Sum(ctx context.Context, req *pb.SumRequest) (*pb.SumReply, error) { + _, rep, err := s.sum.ServeGRPC(ctx, req) + return rep.(*pb.SumReply), err +} + +func (s *grpcServer) Concat(ctx context.Context, req *pb.ConcatRequest) (*pb.ConcatReply, error) { + _, rep, err := s.concat.ServeGRPC(ctx, req) + return rep.(*pb.ConcatReply), err +} + +// DecodeGRPCSumRequest is a transport/grpc.DecodeRequestFunc that converts a +// gRPC sum request to a user-domain sum request. Primarily useful in a server. +func DecodeGRPCSumRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { + req := grpcReq.(*pb.SumRequest) + return sumRequest{A: int(req.A), B: int(req.B)}, nil +} + +// DecodeGRPCConcatRequest is a transport/grpc.DecodeRequestFunc that converts a +// gRPC concat request to a user-domain concat request. Primarily useful in a +// server. +func DecodeGRPCConcatRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { + req := grpcReq.(*pb.ConcatRequest) + return concatRequest{A: req.A, B: req.B}, nil +} + +// DecodeGRPCSumResponse is a transport/grpc.DecodeResponseFunc that converts a +// gRPC sum reply to a user-domain sum response. Primarily useful in a client. +func DecodeGRPCSumResponse(_ context.Context, grpcReply interface{}) (interface{}, error) { + reply := grpcReply.(*pb.SumReply) + return sumResponse{V: int(reply.V)}, nil +} + +// DecodeGRPCConcatResponse is a transport/grpc.DecodeResponseFunc that converts +// a gRPC concat reply to a user-domain concat response. Primarily useful in a +// client. +func DecodeGRPCConcatResponse(_ context.Context, grpcReply interface{}) (interface{}, error) { + reply := grpcReply.(*pb.ConcatReply) + return concatResponse{V: reply.V}, nil +} + +// EncodeGRPCSumResponse is a transport/grpc.EncodeResponseFunc that converts a +// user-domain sum response to a gRPC sum reply. Primarily useful in a server. +func EncodeGRPCSumResponse(_ context.Context, response interface{}) (interface{}, error) { + resp := response.(sumResponse) + return &pb.SumReply{V: int64(resp.V)}, nil +} + +// EncodeGRPCConcatResponse is a transport/grpc.EncodeResponseFunc that converts +// a user-domain concat response to a gRPC concat reply. Primarily useful in a +// server. +func EncodeGRPCConcatResponse(_ context.Context, response interface{}) (interface{}, error) { + resp := response.(concatResponse) + return &pb.ConcatReply{V: resp.V}, nil +} + +// EncodeGRPCSumRequest is a transport/grpc.EncodeRequestFunc that converts a +// user-domain sum request to a gRPC sum request. Primarily useful in a client. +func EncodeGRPCSumRequest(_ context.Context, request interface{}) (interface{}, error) { + req := request.(sumRequest) + return &pb.SumRequest{A: int64(req.A), B: int64(req.B)}, nil +} + +// EncodeGRPCConcatRequest is a transport/grpc.EncodeRequestFunc that converts a +// user-domain concat request to a gRPC concat request. Primarily useful in a +// client. +func EncodeGRPCConcatRequest(_ context.Context, request interface{}) (interface{}, error) { + req := request.(concatRequest) + return &pb.ConcatRequest{A: req.A, B: req.B}, nil +} diff --git a/examples/addsvc/transport_http.go b/examples/addsvc/transport_http.go new file mode 100644 index 000000000..e2d8f6d64 --- /dev/null +++ b/examples/addsvc/transport_http.go @@ -0,0 +1,141 @@ +package addsvc + +// This file provides server-side bindings for the HTTP transport. +// It utilizes the transport/http.Server. + +import ( + "bytes" + "encoding/json" + "errors" + "io/ioutil" + "net/http" + + stdopentracing "github.com/opentracing/opentracing-go" + "golang.org/x/net/context" + + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/tracing/opentracing" + httptransport "github.com/go-kit/kit/transport/http" +) + +// MakeHTTPHandler returns a handler that makes a set of endpoints available +// on predefined paths. +func MakeHTTPHandler(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) http.Handler { + options := []httptransport.ServerOption{ + httptransport.ServerErrorEncoder(errorEncoder), + httptransport.ServerErrorLogger(logger), + } + m := http.NewServeMux() + m.Handle("/sum", httptransport.NewServer( + ctx, + endpoints.SumEndpoint, + DecodeHTTPSumRequest, + EncodeHTTPGenericResponse, + append(options, httptransport.ServerBefore(opentracing.FromHTTPRequest(tracer, "Sum", logger)))..., + )) + m.Handle("/concat", httptransport.NewServer( + ctx, + endpoints.ConcatEndpoint, + DecodeHTTPConcatRequest, + EncodeHTTPGenericResponse, + append(options, httptransport.ServerBefore(opentracing.FromHTTPRequest(tracer, "Concat", logger)))..., + )) + return m +} + +func errorEncoder(_ context.Context, err error, w http.ResponseWriter) { + code := http.StatusInternalServerError + msg := err.Error() + + if e, ok := err.(httptransport.Error); ok { + msg = e.Err.Error() + switch e.Domain { + case httptransport.DomainDecode: + code = http.StatusBadRequest + + case httptransport.DomainDo: + switch e.Err { + case ErrTwoZeroes, ErrMaxSizeExceeded, ErrIntOverflow: + code = http.StatusBadRequest + } + } + } + + w.WriteHeader(code) + json.NewEncoder(w).Encode(errorWrapper{Error: msg}) +} + +func errorDecoder(r *http.Response) error { + var w errorWrapper + if err := json.NewDecoder(r.Body).Decode(&w); err != nil { + return err + } + return errors.New(w.Error) +} + +type errorWrapper struct { + Error string `json:"error"` +} + +// DecodeHTTPSumRequest is a transport/http.DecodeRequestFunc that decodes a +// JSON-encoded sum request from the HTTP request body. Primarily useful in a +// server. +func DecodeHTTPSumRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req sumRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} + +// DecodeHTTPConcatRequest is a transport/http.DecodeRequestFunc that decodes a +// JSON-encoded concat request from the HTTP request body. Primarily useful in a +// server. +func DecodeHTTPConcatRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req concatRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} + +// DecodeHTTPSumResponse is a transport/http.DecodeResponseFunc that decodes a +// JSON-encoded sum response from the HTTP response body. If the response has a +// non-200 status code, we will interpret that as an error and attempt to decode +// the specific error message from the response body. Primarily useful in a +// client. +func DecodeHTTPSumResponse(_ context.Context, r *http.Response) (interface{}, error) { + if r.StatusCode != http.StatusOK { + return nil, errorDecoder(r) + } + var resp sumResponse + err := json.NewDecoder(r.Body).Decode(&resp) + return resp, err +} + +// DecodeHTTPConcatResponse is a transport/http.DecodeResponseFunc that decodes +// a JSON-encoded concat response from the HTTP response body. If the response +// has a non-200 status code, we will interpret that as an error and attempt to +// decode the specific error message from the response body. Primarily useful in +// a client. +func DecodeHTTPConcatResponse(_ context.Context, r *http.Response) (interface{}, error) { + if r.StatusCode != http.StatusOK { + return nil, errorDecoder(r) + } + var resp concatResponse + err := json.NewDecoder(r.Body).Decode(&resp) + return resp, err +} + +// EncodeHTTPGenericRequest is a transport/http.EncodeRequestFunc that +// JSON-encodes any request to the request body. Primarily useful in a client. +func EncodeHTTPGenericRequest(_ context.Context, r *http.Request, request interface{}) error { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(request); err != nil { + return err + } + r.Body = ioutil.NopCloser(&buf) + return nil +} + +// EncodeHTTPGenericResponse is a transport/http.EncodeResponseFunc that encodes +// the response as JSON to the response writer. Primarily useful in a server. +func EncodeHTTPGenericResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + return json.NewEncoder(w).Encode(response) +} diff --git a/examples/addsvc/transport_thrift.go b/examples/addsvc/transport_thrift.go new file mode 100644 index 000000000..3a44950cb --- /dev/null +++ b/examples/addsvc/transport_thrift.go @@ -0,0 +1,76 @@ +package addsvc + +// This file provides server-side bindings for the Thrift transport. +// +// This file also provides endpoint constructors that utilize a Thrift client, +// for use in client packages, because package transport/thrift doesn't exist +// yet. See https://github.com/go-kit/kit/issues/184. + +import ( + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + thriftadd "github.com/go-kit/kit/examples/addsvc/thrift/gen-go/addsvc" +) + +// MakeThriftHandler makes a set of endpoints available as a Thrift service. +func MakeThriftHandler(ctx context.Context, e Endpoints) thriftadd.AddService { + return &thriftServer{ + ctx: ctx, + sum: e.SumEndpoint, + concat: e.ConcatEndpoint, + } +} + +type thriftServer struct { + ctx context.Context + sum endpoint.Endpoint + concat endpoint.Endpoint +} + +func (s *thriftServer) Sum(a int64, b int64) (*thriftadd.SumReply, error) { + request := sumRequest{A: int(a), B: int(b)} + response, err := s.sum(s.ctx, request) + if err != nil { + return nil, err + } + resp := response.(sumResponse) + return &thriftadd.SumReply{Value: int64(resp.V)}, nil +} + +func (s *thriftServer) Concat(a string, b string) (*thriftadd.ConcatReply, error) { + request := concatRequest{A: a, B: b} + response, err := s.concat(s.ctx, request) + if err != nil { + return nil, err + } + resp := response.(concatResponse) + return &thriftadd.ConcatReply{Value: resp.V}, nil +} + +// MakeThriftSumEndpoint returns an endpoint that invokes the passed Thrift client. +// Useful only in clients, and only until a proper transport/thrift.Client exists. +func MakeThriftSumEndpoint(client *thriftadd.AddServiceClient) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(sumRequest) + reply, err := client.Sum(int64(req.A), int64(req.B)) + if err != nil { + return nil, err + } + return sumResponse{V: int(reply.Value)}, nil + } +} + +// MakeThriftConcatEndpoint returns an endpoint that invokes the passed Thrift +// client. Useful only in clients, and only until a proper +// transport/thrift.Client exists. +func MakeThriftConcatEndpoint(client *thriftadd.AddServiceClient) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(concatRequest) + reply, err := client.Concat(req.A, req.B) + if err != nil { + return nil, err + } + return concatResponse{V: reply.Value}, nil + } +} diff --git a/examples/apigateway/main.go b/examples/apigateway/main.go index b01b71fa5..01367ec31 100644 --- a/examples/apigateway/main.go +++ b/examples/apigateway/main.go @@ -1,12 +1,12 @@ package main import ( + "bytes" "encoding/json" "flag" "fmt" "io" "io/ioutil" - stdlog "log" "net/http" "net/url" "os" @@ -17,16 +17,18 @@ import ( "github.com/gorilla/mux" "github.com/hashicorp/consul/api" - "github.com/opentracing/opentracing-go" + stdopentracing "github.com/opentracing/opentracing-go" "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/examples/addsvc/client/grpc" - "github.com/go-kit/kit/examples/addsvc/server" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/consul" + "github.com/go-kit/kit/examples/addsvc" + addsvcgrpcclient "github.com/go-kit/kit/examples/addsvc/client/grpc" "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + consulsd "github.com/go-kit/kit/sd/consul" + "github.com/go-kit/kit/sd/lb" httptransport "github.com/go-kit/kit/transport/http" + "google.golang.org/grpc" ) func main() { @@ -38,153 +40,243 @@ func main() { ) flag.Parse() - // Log domain - logger := log.NewLogfmtLogger(os.Stderr) - logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC).With("caller", log.DefaultCaller) - stdlog.SetFlags(0) // flags are handled by Go kit's logger - stdlog.SetOutput(log.NewStdlibAdapter(logger)) // redirect anything using stdlib log to us + // Logging domain. + var logger log.Logger + { + logger = log.NewLogfmtLogger(os.Stderr) + logger = log.NewContext(logger).With("ts", log.DefaultTimestampUTC) + logger = log.NewContext(logger).With("caller", log.DefaultCaller) + } // Service discovery domain. In this example we use Consul. - consulConfig := api.DefaultConfig() - if len(*consulAddr) > 0 { - consulConfig.Address = *consulAddr - } - consulClient, err := api.NewClient(consulConfig) - if err != nil { - logger.Log("err", err) - os.Exit(1) + var client consulsd.Client + { + consulConfig := api.DefaultConfig() + if len(*consulAddr) > 0 { + consulConfig.Address = *consulAddr + } + consulClient, err := api.NewClient(consulConfig) + if err != nil { + logger.Log("err", err) + os.Exit(1) + } + client = consulsd.NewClient(consulClient) } - discoveryClient := consul.NewClient(consulClient) - // Context domain. + // Transport domain. + tracer := stdopentracing.GlobalTracer() // no-op ctx := context.Background() - - // Set up our routes. - // - // Each Consul service name maps to multiple instances of that service. We - // connect to each instance according to its pre-determined transport: in this - // case, we choose to access addsvc via its gRPC client, and stringsvc over - // plain transport/http (it has no client package). - // - // Each service instance implements multiple methods, and we want to map each - // method to a unique path on the API gateway. So, we define that path and its - // corresponding factory function, which takes an instance string and returns an - // endpoint.Endpoint for the specific method. - // - // Finally, we mount that path + endpoint handler into the router. r := mux.NewRouter() - for consulName, methods := range map[string][]struct { - path string - factory loadbalancer.Factory - }{ - "addsvc": { - {path: "/api/addsvc/concat", factory: grpc.MakeConcatEndpointFactory(opentracing.GlobalTracer(), nil)}, - {path: "/api/addsvc/sum", factory: grpc.MakeSumEndpointFactory(opentracing.GlobalTracer(), nil)}, - }, - "stringsvc": { - {path: "/api/stringsvc/uppercase", factory: httpFactory(ctx, "GET", "uppercase/")}, - {path: "/api/stringsvc/concat", factory: httpFactory(ctx, "GET", "concat/")}, - }, - } { - for _, method := range methods { - publisher, err := consul.NewPublisher(discoveryClient, method.factory, logger, consulName) - if err != nil { - logger.Log("service", consulName, "path", method.path, "err", err) - continue - } - lb := loadbalancer.NewRoundRobin(publisher) - e := loadbalancer.Retry(*retryMax, *retryTimeout, lb) - h := makeHandler(ctx, e, logger) - r.HandleFunc(method.path, h) + + // Now we begin installing the routes. Each route corresponds to a single + // method: sum, concat, uppercase, and count. + + // addsvc routes. + { + // Each method gets constructed with a factory. Factories take an + // instance string, and return a specific endpoint. In the factory we + // dial the instance string we get from Consul, and then leverage an + // addsvc client package to construct a complete service. We can then + // leverage the addsvc.Make{Sum,Concat}Endpoint constructors to convert + // the complete service to specific endpoint. + + var ( + tags = []string{} + passingOnly = true + endpoints = addsvc.Endpoints{} + ) + { + factory := addsvcFactory(addsvc.MakeSumEndpoint, tracer, logger) + subscriber := consulsd.NewSubscriber(client, factory, logger, "addsvc", tags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(*retryMax, *retryTimeout, balancer) + endpoints.SumEndpoint = retry + } + { + factory := addsvcFactory(addsvc.MakeConcatEndpoint, tracer, logger) + subscriber := consulsd.NewSubscriber(client, factory, logger, "addsvc", tags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(*retryMax, *retryTimeout, balancer) + endpoints.ConcatEndpoint = retry + } + + // Here we leverage the fact that addsvc comes with a constructor for an + // HTTP handler, and just install it under a particular path prefix in + // our router. + + r.PathPrefix("addsvc/").Handler(addsvc.MakeHTTPHandler(ctx, endpoints, tracer, logger)) + } + + // stringsvc routes. + { + // addsvc had lots of nice importable Go packages we could leverage. + // With stringsvc we are not so fortunate, it just has some endpoints + // that we assume will exist. So we have to write that logic here. This + // is by design, so you can see two totally different methods of + // proxying to a remote service. + + var ( + tags = []string{} + passingOnly = true + uppercase endpoint.Endpoint + count endpoint.Endpoint + ) + { + factory := stringsvcFactory(ctx, "GET", "/uppercase") + subscriber := consulsd.NewSubscriber(client, factory, logger, "stringsvc", tags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(*retryMax, *retryTimeout, balancer) + uppercase = retry + } + { + factory := stringsvcFactory(ctx, "GET", "/count") + subscriber := consulsd.NewSubscriber(client, factory, logger, "stringsvc", tags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(*retryMax, *retryTimeout, balancer) + count = retry } + + // We can use the transport/http.Server to act as our handler, all we + // have to do provide it with the encode and decode functions for our + // stringsvc methods. + + r.Handle("/stringsvc/uppercase", httptransport.NewServer(ctx, uppercase, decodeUppercaseRequest, encodeJSONResponse)) + r.Handle("/stringsvc/count", httptransport.NewServer(ctx, count, decodeCountRequest, encodeJSONResponse)) } - // Mechanical stuff. + // Interrupt handler. errc := make(chan error) go func() { - errc <- interrupt() + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + errc <- fmt.Errorf("%s", <-c) }() + + // HTTP transport. go func() { - logger.Log("transport", "http", "addr", *httpAddr) + logger.Log("transport", "HTTP", "addr", *httpAddr) errc <- http.ListenAndServe(*httpAddr, r) }() - logger.Log("err", <-errc) + + // Run! + logger.Log("exit", <-errc) } -func makeHandler(ctx context.Context, e endpoint.Endpoint, logger log.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - resp, err := e(ctx, r.Body) - if err != nil { - logger.Log("err", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - b, ok := resp.([]byte) - if !ok { - logger.Log("err", "endpoint response is not of type []byte") - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - _, err = w.Write(b) +func addsvcFactory(makeEndpoint func(addsvc.Service) endpoint.Endpoint, tracer stdopentracing.Tracer, logger log.Logger) sd.Factory { + return func(instance string) (endpoint.Endpoint, io.Closer, error) { + // We could just as easily use the HTTP or Thrift client package to make + // the connection to addsvc. We've chosen gRPC arbitrarily. Note that + // the transport is an implementation detail: it doesn't leak out of + // this function. Nice! + + conn, err := grpc.Dial(instance, grpc.WithInsecure()) if err != nil { - logger.Log("err", err) - return + return nil, nil, err } - } -} + service := addsvcgrpcclient.New(conn, tracer, logger) + endpoint := makeEndpoint(service) -func makeSumEndpoint(svc server.AddService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - r := request.(io.Reader) - var req server.SumRequest - if err := json.NewDecoder(r).Decode(&req); err != nil { - return nil, err - } - v := svc.Sum(req.A, req.B) - return json.Marshal(v) - } -} + // Notice that the addsvc gRPC client converts the connection to a + // complete addsvc, and we just throw away everything except the method + // we're interested in. A smarter factory would mux multiple methods + // over the same connection. But that would require more work to manage + // the returned io.Closer, e.g. reference counting. Since this is for + // the purposes of demonstration, we'll just keep it simple. -func makeConcatEndpoint(svc server.AddService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - r := request.(io.Reader) - var req server.ConcatRequest - if err := json.NewDecoder(r).Decode(&req); err != nil { - return nil, err - } - v := svc.Concat(req.A, req.B) - return json.Marshal(v) + return endpoint, conn, nil } } -func httpFactory(ctx context.Context, method, path string) loadbalancer.Factory { +func stringsvcFactory(ctx context.Context, method, path string) sd.Factory { return func(instance string) (endpoint.Endpoint, io.Closer, error) { - var e endpoint.Endpoint if !strings.HasPrefix(instance, "http") { instance = "http://" + instance } - u, err := url.Parse(instance) + tgt, err := url.Parse(instance) if err != nil { return nil, nil, err } - u.Path = path + tgt.Path = path + + // Since stringsvc doesn't have any kind of package we can import, or + // any formal spec, we are forced to just assert where the endpoints + // live, and write our own code to encode and decode requests and + // responses. Ideally, if you write the service, you will want to + // provide stronger guarantees to your clients. + + var ( + enc httptransport.EncodeRequestFunc + dec httptransport.DecodeResponseFunc + ) + switch path { + case "/uppercase": + enc, dec = encodeJSONRequest, decodeUppercaseResponse + case "/count": + enc, dec = encodeJSONRequest, decodeCountResponse + default: + return nil, nil, fmt.Errorf("unknown stringsvc path %q", path) + } - e = httptransport.NewClient(method, u, passEncode, passDecode).Endpoint() - return e, nil, nil + return httptransport.NewClient(method, tgt, enc, dec).Endpoint(), nil, nil } } -func passEncode(_ context.Context, r *http.Request, request interface{}) error { - r.Body = request.(io.ReadCloser) +func encodeJSONRequest(_ context.Context, req *http.Request, request interface{}) error { + // Both uppercase and count requests are encoded in the same way: + // simple JSON serialization to the request body. + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(request); err != nil { + return err + } + req.Body = ioutil.NopCloser(&buf) return nil } -func passDecode(_ context.Context, r *http.Response) (interface{}, error) { - return ioutil.ReadAll(r.Body) +func encodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) } -func interrupt() error { - c := make(chan os.Signal) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) - return fmt.Errorf("%s", <-c) +// I've just copied these functions from stringsvc3/transport.go, inlining the +// struct definitions. + +func decodeUppercaseResponse(ctx context.Context, resp *http.Response) (interface{}, error) { + var response struct { + V string `json:"v"` + Err string `json:"err,omitempty"` + } + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, err + } + return response, nil +} + +func decodeCountResponse(ctx context.Context, resp *http.Response) (interface{}, error) { + var response struct { + V int `json:"v"` + } + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, err + } + return response, nil +} + +func decodeUppercaseRequest(ctx context.Context, req *http.Request) (interface{}, error) { + var request struct { + S string `json:"s"` + } + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return nil, err + } + return request, nil +} + +func decodeCountRequest(ctx context.Context, req *http.Request) (interface{}, error) { + var request struct { + S string `json:"s"` + } + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return nil, err + } + return request, nil } diff --git a/examples/profilesvc/client/client.go b/examples/profilesvc/client/client.go new file mode 100644 index 000000000..6b1dff064 --- /dev/null +++ b/examples/profilesvc/client/client.go @@ -0,0 +1,120 @@ +// Package client provides a profilesvc client based on a predefined Consul +// service name and relevant tags. Users must only provide the address of a +// Consul server. +package client + +import ( + "io" + "time" + + consulapi "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/examples/profilesvc" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/consul" + "github.com/go-kit/kit/sd/lb" +) + +// New returns a service that's load-balanced over instances of profilesvc found +// in the provided Consul server. The mechanism of looking up profilesvc +// instances in Consul is hard-coded into the client. +func New(consulAddr string, logger log.Logger) (profilesvc.Service, error) { + apiclient, err := consulapi.NewClient(&consulapi.Config{ + Address: consulAddr, + }) + if err != nil { + return nil, err + } + + // As the implementer of profilesvc, we declare and enforce these + // parameters for all of the profilesvc consumers. + var ( + consulService = "profilesvc" + consulTags = []string{"prod"} + passingOnly = true + retryMax = 3 + retryTimeout = 500 * time.Millisecond + ) + + var ( + sdclient = consul.NewClient(apiclient) + endpoints profilesvc.Endpoints + ) + { + factory := factoryFor(profilesvc.MakePostProfileEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.PostProfileEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakeGetProfileEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.GetProfileEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakePutProfileEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.PutProfileEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakePatchProfileEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.PatchProfileEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakeDeleteProfileEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.DeleteProfileEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakeGetAddressesEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.GetAddressesEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakeGetAddressEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.GetAddressEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakePostAddressEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.PostAddressEndpoint = retry + } + { + factory := factoryFor(profilesvc.MakeDeleteAddressEndpoint) + subscriber := consul.NewSubscriber(sdclient, factory, logger, consulService, consulTags, passingOnly) + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(retryMax, retryTimeout, balancer) + endpoints.DeleteAddressEndpoint = retry + } + + return endpoints, nil +} + +func factoryFor(makeEndpoint func(profilesvc.Service) endpoint.Endpoint) sd.Factory { + return func(instance string) (endpoint.Endpoint, io.Closer, error) { + service, err := profilesvc.MakeClientEndpoints(instance) + if err != nil { + return nil, nil, err + } + return makeEndpoint(service), nil, nil + } +} diff --git a/examples/profilesvc/main.go b/examples/profilesvc/cmd/profilesvc/main.go similarity index 63% rename from examples/profilesvc/main.go rename to examples/profilesvc/cmd/profilesvc/main.go index 5dfc082d8..a340e69da 100644 --- a/examples/profilesvc/main.go +++ b/examples/profilesvc/cmd/profilesvc/main.go @@ -10,6 +10,7 @@ import ( "golang.org/x/net/context" + "github.com/go-kit/kit/examples/profilesvc" "github.com/go-kit/kit/log" ) @@ -31,27 +32,28 @@ func main() { ctx = context.Background() } - var s ProfileService + var s profilesvc.Service { - s = newInmemService() - s = loggingMiddleware{s, log.NewContext(logger).With("component", "svc")} + s = profilesvc.NewInmemService() + s = profilesvc.LoggingMiddleware(logger)(s) } var h http.Handler { - h = makeHandler(ctx, s, log.NewContext(logger).With("component", "http")) + h = profilesvc.MakeHTTPHandler(ctx, s, log.NewContext(logger).With("component", "HTTP")) } - errs := make(chan error, 2) - go func() { - logger.Log("transport", "http", "address", *httpAddr, "msg", "listening") - errs <- http.ListenAndServe(*httpAddr, h) - }() + errs := make(chan error) go func() { c := make(chan os.Signal) - signal.Notify(c, syscall.SIGINT) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) errs <- fmt.Errorf("%s", <-c) }() - logger.Log("terminated", <-errs) + go func() { + logger.Log("transport", "HTTP", "addr", *httpAddr) + errs <- http.ListenAndServe(*httpAddr, h) + }() + + logger.Log("exit", <-errs) } diff --git a/examples/profilesvc/endpoints.go b/examples/profilesvc/endpoints.go index 062cf9afb..6dd129f83 100644 --- a/examples/profilesvc/endpoints.go +++ b/examples/profilesvc/endpoints.go @@ -1,58 +1,192 @@ -package main +package profilesvc import ( - "github.com/go-kit/kit/endpoint" + "net/url" + "strings" + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + httptransport "github.com/go-kit/kit/transport/http" ) -type endpoints struct { - postProfileEndpoint endpoint.Endpoint - getProfileEndpoint endpoint.Endpoint - putProfileEndpoint endpoint.Endpoint - patchProfileEndpoint endpoint.Endpoint - deleteProfileEndpoint endpoint.Endpoint - getAddressesEndpoint endpoint.Endpoint - getAddressEndpoint endpoint.Endpoint - postAddressEndpoint endpoint.Endpoint - deleteAddressEndpoint endpoint.Endpoint +// Endpoints collects all of the endpoints that compose a profile service. It's +// meant to be used as a helper struct, to collect all of the endpoints into a +// single parameter. +// +// In a server, it's useful for functions that need to operate on a per-endpoint +// basis. For example, you might pass an Endpoints to a function that produces +// an http.Handler, with each method (endpoint) wired up to a specific path. (It +// is probably a mistake in design to invoke the Service methods on the +// Endpoints struct in a server.) +// +// In a client, it's useful to collect individually constructed endpoints into a +// single type that implements the Service interface. For example, you might +// construct individual endpoints using transport/http.NewClient, combine them +// into an Endpoints, and return it to the caller as a Service. +type Endpoints struct { + PostProfileEndpoint endpoint.Endpoint + GetProfileEndpoint endpoint.Endpoint + PutProfileEndpoint endpoint.Endpoint + PatchProfileEndpoint endpoint.Endpoint + DeleteProfileEndpoint endpoint.Endpoint + GetAddressesEndpoint endpoint.Endpoint + GetAddressEndpoint endpoint.Endpoint + PostAddressEndpoint endpoint.Endpoint + DeleteAddressEndpoint endpoint.Endpoint } -func makeEndpoints(s ProfileService) endpoints { - return endpoints{ - postProfileEndpoint: makePostProfileEndpoint(s), - getProfileEndpoint: makeGetProfileEndpoint(s), - putProfileEndpoint: makePutProfileEndpoint(s), - patchProfileEndpoint: makePatchProfileEndpoint(s), - deleteProfileEndpoint: makeDeleteProfileEndpoint(s), - getAddressesEndpoint: makeGetAddressesEndpoint(s), - getAddressEndpoint: makeGetAddressEndpoint(s), - postAddressEndpoint: makePostAddressEndpoint(s), - deleteAddressEndpoint: makeDeleteAddressEndpoint(s), +// MakeServerEndpoints returns an Endpoints struct where each endpoint invokes +// the corresponding method on the provided service. Useful in a profilesvc +// server. +func MakeServerEndpoints(s Service) Endpoints { + return Endpoints{ + PostProfileEndpoint: MakePostProfileEndpoint(s), + GetProfileEndpoint: MakeGetProfileEndpoint(s), + PutProfileEndpoint: MakePutProfileEndpoint(s), + PatchProfileEndpoint: MakePatchProfileEndpoint(s), + DeleteProfileEndpoint: MakeDeleteProfileEndpoint(s), + GetAddressesEndpoint: MakeGetAddressesEndpoint(s), + GetAddressEndpoint: MakeGetAddressEndpoint(s), + PostAddressEndpoint: MakePostAddressEndpoint(s), + DeleteAddressEndpoint: MakeDeleteAddressEndpoint(s), } } -type postProfileRequest struct { - Profile Profile +// MakeClientEndpoints returns an Endpoints struct where each endpoint invokes +// the corresponding method on the remote instance, via a transport/http.Client. +// Useful in a profilesvc client. +func MakeClientEndpoints(instance string) (Endpoints, error) { + if !strings.HasPrefix(instance, "http") { + instance = "http://" + instance + } + tgt, err := url.Parse(instance) + if err != nil { + return Endpoints{}, err + } + tgt.Path = "" + + options := []httptransport.ClientOption{} + + // Note that the request encoders need to modify the request URL, changing + // the path and method. That's fine: we simply need to provide specific + // encoders for each endpoint. + + return Endpoints{ + PostProfileEndpoint: httptransport.NewClient("POST", tgt, encodePostProfileRequest, decodePostProfileResponse, options...).Endpoint(), + GetProfileEndpoint: httptransport.NewClient("GET", tgt, encodeGetProfileRequest, decodeGetProfileResponse, options...).Endpoint(), + PutProfileEndpoint: httptransport.NewClient("PUT", tgt, encodePutProfileRequest, decodePutProfileResponse, options...).Endpoint(), + PatchProfileEndpoint: httptransport.NewClient("PATCH", tgt, encodePatchProfileRequest, decodePatchProfileResponse, options...).Endpoint(), + DeleteProfileEndpoint: httptransport.NewClient("DELETE", tgt, encodeDeleteProfileRequest, decodeDeleteProfileResponse, options...).Endpoint(), + GetAddressesEndpoint: httptransport.NewClient("GET", tgt, encodeGetAddressesRequest, decodeGetAddressesResponse, options...).Endpoint(), + GetAddressEndpoint: httptransport.NewClient("GET", tgt, encodeGetAddressRequest, decodeGetAddressResponse, options...).Endpoint(), + PostAddressEndpoint: httptransport.NewClient("POST", tgt, encodePostAddressRequest, decodePostAddressResponse, options...).Endpoint(), + DeleteAddressEndpoint: httptransport.NewClient("DELETE", tgt, encodeDeleteAddressRequest, decodeDeleteAddressResponse, options...).Endpoint(), + }, nil } -type postProfileResponse struct { - Err error `json:"err,omitempty"` +// PostProfile implements Service. Primarily useful in a client. +func (e Endpoints) PostProfile(ctx context.Context, p Profile) error { + request := postProfileRequest{Profile: p} + response, err := e.PostProfileEndpoint(ctx, request) + if err != nil { + return err + } + resp := response.(postProfileResponse) + return resp.Err } -func (r postProfileResponse) error() error { return r.Err } +// GetProfile implements Service. Primarily useful in a client. +func (e Endpoints) GetProfile(ctx context.Context, id string) (Profile, error) { + request := getProfileRequest{ID: id} + response, err := e.GetProfileEndpoint(ctx, request) + if err != nil { + return Profile{}, err + } + resp := response.(getProfileResponse) + return resp.Profile, resp.Err +} + +// PutProfile implements Service. Primarily useful in a client. +func (e Endpoints) PutProfile(ctx context.Context, id string, p Profile) error { + request := putProfileRequest{ID: id, Profile: p} + response, err := e.PutProfileEndpoint(ctx, request) + if err != nil { + return err + } + resp := response.(putProfileResponse) + return resp.Err +} + +// PatchProfile implements Service. Primarily useful in a client. +func (e Endpoints) PatchProfile(ctx context.Context, id string, p Profile) error { + request := patchProfileRequest{ID: id, Profile: p} + response, err := e.PatchProfileEndpoint(ctx, request) + if err != nil { + return err + } + resp := response.(patchProfileResponse) + return resp.Err +} -// Regarding errors returned from service (business logic) methods, we have two -// options. We could return the error via the endpoint itself. That makes -// certain things a little bit easier, like providing non-200 HTTP responses to -// the client. But Go kit assumes that endpoint errors are (or may be treated -// as) transport-domain errors. For example, an endpoint error will count -// against a circuit breaker error count. Therefore, it's almost certainly -// better to return service (business logic) errors in the response object. This -// means we have to do a bit more work in the HTTP response encoder to detect -// e.g. a not-found error and provide a proper HTTP status code. That work is -// done with the errorer interface, in transport.go. - -func makePostProfileEndpoint(s ProfileService) endpoint.Endpoint { +// DeleteProfile implements Service. Primarily useful in a client. +func (e Endpoints) DeleteProfile(ctx context.Context, id string) error { + request := deleteProfileRequest{ID: id} + response, err := e.DeleteProfileEndpoint(ctx, request) + if err != nil { + return err + } + resp := response.(deleteProfileResponse) + return resp.Err +} + +// GetAddresses implements Service. Primarily useful in a client. +func (e Endpoints) GetAddresses(ctx context.Context, profileID string) ([]Address, error) { + request := getAddressesRequest{ProfileID: profileID} + response, err := e.GetAddressesEndpoint(ctx, request) + if err != nil { + return nil, err + } + resp := response.(getAddressesResponse) + return resp.Addresses, resp.Err +} + +// GetAddress implements Service. Primarily useful in a client. +func (e Endpoints) GetAddress(ctx context.Context, profileID string, addressID string) (Address, error) { + request := getAddressRequest{ProfileID: profileID, AddressID: addressID} + response, err := e.GetAddressEndpoint(ctx, request) + if err != nil { + return Address{}, err + } + resp := response.(getAddressResponse) + return resp.Address, resp.Err +} + +// PostAddress implements Service. Primarily useful in a client. +func (e Endpoints) PostAddress(ctx context.Context, profileID string, a Address) error { + request := postAddressRequest{ProfileID: profileID, Address: a} + response, err := e.PostAddressEndpoint(ctx, request) + if err != nil { + return err + } + resp := response.(postAddressResponse) + return resp.Err +} + +// DeleteAddress implements Service. Primarily useful in a client. +func (e Endpoints) DeleteAddress(ctx context.Context, profileID string, addressID string) error { + request := deleteAddressRequest{ProfileID: profileID, AddressID: addressID} + response, err := e.DeleteAddressEndpoint(ctx, request) + if err != nil { + return err + } + resp := response.(deleteAddressResponse) + return resp.Err +} + +// MakePostProfileEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakePostProfileEndpoint(s Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(postProfileRequest) e := s.PostProfile(ctx, req.Profile) @@ -60,6 +194,111 @@ func makePostProfileEndpoint(s ProfileService) endpoint.Endpoint { } } +// MakeGetProfileEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakeGetProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(getProfileRequest) + p, e := s.GetProfile(ctx, req.ID) + return getProfileResponse{Profile: p, Err: e}, nil + } +} + +// MakePutProfileEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakePutProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(putProfileRequest) + e := s.PutProfile(ctx, req.ID, req.Profile) + return putProfileResponse{Err: e}, nil + } +} + +// MakePatchProfileEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakePatchProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(patchProfileRequest) + e := s.PatchProfile(ctx, req.ID, req.Profile) + return patchProfileResponse{Err: e}, nil + } +} + +// MakeDeleteProfileEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakeDeleteProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(deleteProfileRequest) + e := s.DeleteProfile(ctx, req.ID) + return deleteProfileResponse{Err: e}, nil + } +} + +// MakeGetAddressesEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakeGetAddressesEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(getAddressesRequest) + a, e := s.GetAddresses(ctx, req.ProfileID) + return getAddressesResponse{Addresses: a, Err: e}, nil + } +} + +// MakeGetAddressEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakeGetAddressEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(getAddressRequest) + a, e := s.GetAddress(ctx, req.ProfileID, req.AddressID) + return getAddressResponse{Address: a, Err: e}, nil + } +} + +// MakePostAddressEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakePostAddressEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(postAddressRequest) + e := s.PostAddress(ctx, req.ProfileID, req.Address) + return postAddressResponse{Err: e}, nil + } +} + +// MakeDeleteAddressEndpoint returns an endpoint via the passed service. +// Primarily useful in a server. +func MakeDeleteAddressEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(deleteAddressRequest) + e := s.DeleteAddress(ctx, req.ProfileID, req.AddressID) + return deleteAddressResponse{Err: e}, nil + } +} + +// We have two options to return errors from the business logic. +// +// We could return the error via the endpoint itself. That makes certain things +// a little bit easier, like providing non-200 HTTP responses to the client. But +// Go kit assumes that endpoint errors are (or may be treated as) +// transport-domain errors. For example, an endpoint error will count against a +// circuit breaker error count. +// +// Therefore, it's often better to return service (business logic) errors in the +// response object. This means we have to do a bit more work in the HTTP +// response encoder to detect e.g. a not-found error and provide a proper HTTP +// status code. That work is done with the errorer interface, in transport.go. +// Response types that may contain business-logic errors implement that +// interface. + +type postProfileRequest struct { + Profile Profile +} + +type postProfileResponse struct { + Err error `json:"err,omitempty"` +} + +func (r postProfileResponse) error() error { return r.Err } + type getProfileRequest struct { ID string } @@ -71,14 +310,6 @@ type getProfileResponse struct { func (r getProfileResponse) error() error { return r.Err } -func makeGetProfileEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(getProfileRequest) - p, e := s.GetProfile(ctx, req.ID) - return getProfileResponse{Profile: p, Err: e}, nil - } -} - type putProfileRequest struct { ID string Profile Profile @@ -90,14 +321,6 @@ type putProfileResponse struct { func (r putProfileResponse) error() error { return nil } -func makePutProfileEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(putProfileRequest) - e := s.PutProfile(ctx, req.ID, req.Profile) - return putProfileResponse{Err: e}, nil - } -} - type patchProfileRequest struct { ID string Profile Profile @@ -109,14 +332,6 @@ type patchProfileResponse struct { func (r patchProfileResponse) error() error { return r.Err } -func makePatchProfileEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(patchProfileRequest) - e := s.PatchProfile(ctx, req.ID, req.Profile) - return patchProfileResponse{Err: e}, nil - } -} - type deleteProfileRequest struct { ID string } @@ -127,14 +342,6 @@ type deleteProfileResponse struct { func (r deleteProfileResponse) error() error { return r.Err } -func makeDeleteProfileEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(deleteProfileRequest) - e := s.DeleteProfile(ctx, req.ID) - return deleteProfileResponse{Err: e}, nil - } -} - type getAddressesRequest struct { ProfileID string } @@ -146,14 +353,6 @@ type getAddressesResponse struct { func (r getAddressesResponse) error() error { return r.Err } -func makeGetAddressesEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(getAddressesRequest) - a, e := s.GetAddresses(ctx, req.ProfileID) - return getAddressesResponse{Addresses: a, Err: e}, nil - } -} - type getAddressRequest struct { ProfileID string AddressID string @@ -166,14 +365,6 @@ type getAddressResponse struct { func (r getAddressResponse) error() error { return r.Err } -func makeGetAddressEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(getAddressRequest) - a, e := s.GetAddress(ctx, req.ProfileID, req.AddressID) - return getAddressResponse{Address: a, Err: e}, nil - } -} - type postAddressRequest struct { ProfileID string Address Address @@ -185,14 +376,6 @@ type postAddressResponse struct { func (r postAddressResponse) error() error { return r.Err } -func makePostAddressEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(postAddressRequest) - e := s.PostAddress(ctx, req.ProfileID, req.Address) - return postAddressResponse{Err: e}, nil - } -} - type deleteAddressRequest struct { ProfileID string AddressID string @@ -203,11 +386,3 @@ type deleteAddressResponse struct { } func (r deleteAddressResponse) error() error { return r.Err } - -func makeDeleteAddressEndpoint(s ProfileService) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(deleteAddressRequest) - e := s.DeleteAddress(ctx, req.ProfileID, req.AddressID) - return deleteAddressResponse{Err: e}, nil - } -} diff --git a/examples/profilesvc/middlewares.go b/examples/profilesvc/middlewares.go index 698e2683f..76708e594 100644 --- a/examples/profilesvc/middlewares.go +++ b/examples/profilesvc/middlewares.go @@ -1,4 +1,4 @@ -package main +package profilesvc import ( "time" @@ -8,8 +8,20 @@ import ( "github.com/go-kit/kit/log" ) +// Middleware describes a service (as opposed to endpoint) middleware. +type Middleware func(Service) Service + +func LoggingMiddleware(logger log.Logger) Middleware { + return func(next Service) Service { + return &loggingMiddleware{ + next: next, + logger: logger, + } + } +} + type loggingMiddleware struct { - next ProfileService + next Service logger log.Logger } diff --git a/examples/profilesvc/service.go b/examples/profilesvc/service.go index 9e121002e..4ae67561f 100644 --- a/examples/profilesvc/service.go +++ b/examples/profilesvc/service.go @@ -1,4 +1,4 @@ -package main +package profilesvc import ( "errors" @@ -7,8 +7,8 @@ import ( "golang.org/x/net/context" ) -// ProfileService is a simple CRUD interface for user profiles. -type ProfileService interface { +// Service is a simple CRUD interface for user profiles. +type Service interface { PostProfile(ctx context.Context, p Profile) error GetProfile(ctx context.Context, id string) (Profile, error) PutProfile(ctx context.Context, id string, p Profile) error @@ -36,9 +36,9 @@ type Address struct { } var ( - errInconsistentIDs = errors.New("inconsistent IDs") - errAlreadyExists = errors.New("already exists") - errNotFound = errors.New("not found") + ErrInconsistentIDs = errors.New("inconsistent IDs") + ErrAlreadyExists = errors.New("already exists") + ErrNotFound = errors.New("not found") ) type inmemService struct { @@ -46,7 +46,7 @@ type inmemService struct { m map[string]Profile } -func newInmemService() ProfileService { +func NewInmemService() Service { return &inmemService{ m: map[string]Profile{}, } @@ -56,7 +56,7 @@ func (s *inmemService) PostProfile(ctx context.Context, p Profile) error { s.mtx.Lock() defer s.mtx.Unlock() if _, ok := s.m[p.ID]; ok { - return errAlreadyExists // POST = create, don't overwrite + return ErrAlreadyExists // POST = create, don't overwrite } s.m[p.ID] = p return nil @@ -67,14 +67,14 @@ func (s *inmemService) GetProfile(ctx context.Context, id string) (Profile, erro defer s.mtx.RUnlock() p, ok := s.m[id] if !ok { - return Profile{}, errNotFound + return Profile{}, ErrNotFound } return p, nil } func (s *inmemService) PutProfile(ctx context.Context, id string, p Profile) error { if id != p.ID { - return errInconsistentIDs + return ErrInconsistentIDs } s.mtx.Lock() defer s.mtx.Unlock() @@ -84,7 +84,7 @@ func (s *inmemService) PutProfile(ctx context.Context, id string, p Profile) err func (s *inmemService) PatchProfile(ctx context.Context, id string, p Profile) error { if p.ID != "" && id != p.ID { - return errInconsistentIDs + return ErrInconsistentIDs } s.mtx.Lock() @@ -92,7 +92,7 @@ func (s *inmemService) PatchProfile(ctx context.Context, id string, p Profile) e existing, ok := s.m[id] if !ok { - return errNotFound // PATCH = update existing, don't create + return ErrNotFound // PATCH = update existing, don't create } // We assume that it's not possible to PATCH the ID, and that it's not @@ -115,7 +115,7 @@ func (s *inmemService) DeleteProfile(ctx context.Context, id string) error { s.mtx.Lock() defer s.mtx.Unlock() if _, ok := s.m[id]; !ok { - return errNotFound + return ErrNotFound } delete(s.m, id) return nil @@ -126,7 +126,7 @@ func (s *inmemService) GetAddresses(ctx context.Context, profileID string) ([]Ad defer s.mtx.RUnlock() p, ok := s.m[profileID] if !ok { - return []Address{}, errNotFound + return []Address{}, ErrNotFound } return p.Addresses, nil } @@ -136,14 +136,14 @@ func (s *inmemService) GetAddress(ctx context.Context, profileID string, address defer s.mtx.RUnlock() p, ok := s.m[profileID] if !ok { - return Address{}, errNotFound + return Address{}, ErrNotFound } for _, address := range p.Addresses { if address.ID == addressID { return address, nil } } - return Address{}, errNotFound + return Address{}, ErrNotFound } func (s *inmemService) PostAddress(ctx context.Context, profileID string, a Address) error { @@ -151,11 +151,11 @@ func (s *inmemService) PostAddress(ctx context.Context, profileID string, a Addr defer s.mtx.Unlock() p, ok := s.m[profileID] if !ok { - return errNotFound + return ErrNotFound } for _, address := range p.Addresses { if address.ID == a.ID { - return errAlreadyExists + return ErrAlreadyExists } } p.Addresses = append(p.Addresses, a) @@ -168,7 +168,7 @@ func (s *inmemService) DeleteAddress(ctx context.Context, profileID string, addr defer s.mtx.Unlock() p, ok := s.m[profileID] if !ok { - return errNotFound + return ErrNotFound } newAddresses := make([]Address, 0, len(p.Addresses)) for _, address := range p.Addresses { @@ -178,7 +178,7 @@ func (s *inmemService) DeleteAddress(ctx context.Context, profileID string, addr newAddresses = append(newAddresses, address) } if len(newAddresses) == len(p.Addresses) { - return errNotFound + return ErrNotFound } p.Addresses = newAddresses s.m[profileID] = p diff --git a/examples/profilesvc/transport.go b/examples/profilesvc/transport.go index 0874a1079..02d807c93 100644 --- a/examples/profilesvc/transport.go +++ b/examples/profilesvc/transport.go @@ -1,28 +1,37 @@ -package main +package profilesvc + +// The profilesvc is just over HTTP, so we just have a single transport.go. import ( + "bytes" "encoding/json" "errors" - stdhttp "net/http" + "io/ioutil" + "net/http" "github.com/gorilla/mux" "golang.org/x/net/context" - kitlog "github.com/go-kit/kit/log" - kithttp "github.com/go-kit/kit/transport/http" + "net/url" + + "github.com/go-kit/kit/log" + httptransport "github.com/go-kit/kit/transport/http" ) var ( - errBadRouting = errors.New("inconsistent mapping between route and handler (programmer error)") + // ErrBadRouting is returned when an expected path variable is missing. + // It always indicates programmer error. + ErrBadRouting = errors.New("inconsistent mapping between route and handler (programmer error)") ) -func makeHandler(ctx context.Context, s ProfileService, logger kitlog.Logger) stdhttp.Handler { - e := makeEndpoints(s) +// MakeHTTPHandler mounts all of the service endpoints into an http.Handler. +// Useful in a profilesvc server. +func MakeHTTPHandler(ctx context.Context, s Service, logger log.Logger) http.Handler { r := mux.NewRouter() - - commonOptions := []kithttp.ServerOption{ - kithttp.ServerErrorLogger(logger), - kithttp.ServerErrorEncoder(encodeError), + e := MakeServerEndpoints(s) + options := []httptransport.ServerOption{ + httptransport.ServerErrorLogger(logger), + httptransport.ServerErrorEncoder(encodeError), } // POST /profiles adds another profile @@ -35,73 +44,73 @@ func makeHandler(ctx context.Context, s ProfileService, logger kitlog.Logger) st // POST /profiles/:id/addresses add a new address // DELETE /profiles/:id/addresses/:addressID remove an address - r.Methods("POST").Path("/profiles/").Handler(kithttp.NewServer( + r.Methods("POST").Path("/profiles/").Handler(httptransport.NewServer( ctx, - e.postProfileEndpoint, + e.PostProfileEndpoint, decodePostProfileRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("GET").Path("/profiles/{id}").Handler(kithttp.NewServer( + r.Methods("GET").Path("/profiles/{id}").Handler(httptransport.NewServer( ctx, - e.getProfileEndpoint, + e.GetProfileEndpoint, decodeGetProfileRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("PUT").Path("/profiles/{id}").Handler(kithttp.NewServer( + r.Methods("PUT").Path("/profiles/{id}").Handler(httptransport.NewServer( ctx, - e.putProfileEndpoint, + e.PutProfileEndpoint, decodePutProfileRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("PATCH").Path("/profiles/{id}").Handler(kithttp.NewServer( + r.Methods("PATCH").Path("/profiles/{id}").Handler(httptransport.NewServer( ctx, - e.patchProfileEndpoint, + e.PatchProfileEndpoint, decodePatchProfileRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("DELETE").Path("/profiles/{id}").Handler(kithttp.NewServer( + r.Methods("DELETE").Path("/profiles/{id}").Handler(httptransport.NewServer( ctx, - e.deleteProfileEndpoint, + e.DeleteProfileEndpoint, decodeDeleteProfileRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("GET").Path("/profiles/{id}/addresses/").Handler(kithttp.NewServer( + r.Methods("GET").Path("/profiles/{id}/addresses/").Handler(httptransport.NewServer( ctx, - e.getAddressesEndpoint, + e.GetAddressesEndpoint, decodeGetAddressesRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("GET").Path("/profiles/{id}/addresses/{addressID}").Handler(kithttp.NewServer( + r.Methods("GET").Path("/profiles/{id}/addresses/{addressID}").Handler(httptransport.NewServer( ctx, - e.getAddressEndpoint, + e.GetAddressEndpoint, decodeGetAddressRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("POST").Path("/profiles/{id}/addresses/").Handler(kithttp.NewServer( + r.Methods("POST").Path("/profiles/{id}/addresses/").Handler(httptransport.NewServer( ctx, - e.postAddressEndpoint, + e.PostAddressEndpoint, decodePostAddressRequest, encodeResponse, - commonOptions..., + options..., )) - r.Methods("DELETE").Path("/profiles/{id}/addresses/{addressID}").Handler(kithttp.NewServer( + r.Methods("DELETE").Path("/profiles/{id}/addresses/{addressID}").Handler(httptransport.NewServer( ctx, - e.deleteAddressEndpoint, + e.DeleteAddressEndpoint, decodeDeleteAddressRequest, encodeResponse, - commonOptions..., + options..., )) return r } -func decodePostProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodePostProfileRequest(_ context.Context, r *http.Request) (request interface{}, err error) { var req postProfileRequest if e := json.NewDecoder(r.Body).Decode(&req.Profile); e != nil { return nil, e @@ -109,20 +118,20 @@ func decodePostProfileRequest(_ context.Context, r *stdhttp.Request) (request in return req, nil } -func decodeGetProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodeGetProfileRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } return getProfileRequest{ID: id}, nil } -func decodePutProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodePutProfileRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } var profile Profile if err := json.NewDecoder(r.Body).Decode(&profile); err != nil { @@ -134,11 +143,11 @@ func decodePutProfileRequest(_ context.Context, r *stdhttp.Request) (request int }, nil } -func decodePatchProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodePatchProfileRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } var profile Profile if err := json.NewDecoder(r.Body).Decode(&profile); err != nil { @@ -150,33 +159,33 @@ func decodePatchProfileRequest(_ context.Context, r *stdhttp.Request) (request i }, nil } -func decodeDeleteProfileRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodeDeleteProfileRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } return deleteProfileRequest{ID: id}, nil } -func decodeGetAddressesRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodeGetAddressesRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } return getAddressesRequest{ProfileID: id}, nil } -func decodeGetAddressRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodeGetAddressRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } addressID, ok := vars["addressID"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } return getAddressRequest{ ProfileID: id, @@ -184,11 +193,11 @@ func decodeGetAddressRequest(_ context.Context, r *stdhttp.Request) (request int }, nil } -func decodePostAddressRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodePostAddressRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } var address Address if err := json.NewDecoder(r.Body).Decode(&address); err != nil { @@ -200,15 +209,15 @@ func decodePostAddressRequest(_ context.Context, r *stdhttp.Request) (request in }, nil } -func decodeDeleteAddressRequest(_ context.Context, r *stdhttp.Request) (request interface{}, err error) { +func decodeDeleteAddressRequest(_ context.Context, r *http.Request) (request interface{}, err error) { vars := mux.Vars(r) id, ok := vars["id"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } addressID, ok := vars["addressID"] if !ok { - return nil, errBadRouting + return nil, ErrBadRouting } return deleteAddressRequest{ ProfileID: id, @@ -216,32 +225,163 @@ func decodeDeleteAddressRequest(_ context.Context, r *stdhttp.Request) (request }, nil } -// errorer is implemented by all concrete response types. It allows us to -// change the HTTP response code without needing to trigger an endpoint -// (transport-level) error. For more information, read the big comment in -// endpoints.go. +func encodePostProfileRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("POST").Path("/profiles/") + req.Method, req.URL.Path = "POST", url.QueryEscape("/profiles/") + return encodeRequest(ctx, req, request) +} + +func encodeGetProfileRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("GET").Path("/profiles/{id}") + r := request.(getProfileRequest) + req.Method, req.URL.Path = "GET", url.QueryEscape("/profiles/"+r.ID) + return encodeRequest(ctx, req, request) +} + +func encodePutProfileRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("PUT").Path("/profiles/{id}") + r := request.(putProfileRequest) + req.Method, req.URL.Path = "PUT", url.QueryEscape("/profiles/"+r.ID) + return encodeRequest(ctx, req, request) +} + +func encodePatchProfileRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("PATCH").Path("/profiles/{id}") + r := request.(patchProfileRequest) + req.Method, req.URL.Path = "PATCH", url.QueryEscape("/profiles/"+r.ID) + return encodeRequest(ctx, req, request) +} + +func encodeDeleteProfileRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("DELETE").Path("/profiles/{id}") + r := request.(deleteProfileRequest) + req.Method, req.URL.Path = "DELETE", url.QueryEscape("/profiles/"+r.ID) + return encodeRequest(ctx, req, request) +} + +func encodeGetAddressesRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("GET").Path("/profiles/{id}/addresses/") + r := request.(getAddressesRequest) + req.Method, req.URL.Path = "GET", url.QueryEscape("/profiles/"+r.ProfileID+"/addresses/") + return encodeRequest(ctx, req, request) +} + +func encodeGetAddressRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("GET").Path("/profiles/{id}/addresses/{addressID}") + r := request.(getAddressRequest) + req.Method, req.URL.Path = "GET", url.QueryEscape("/profiles/"+r.ProfileID+"/addresses/"+r.AddressID) + return encodeRequest(ctx, req, request) +} + +func encodePostAddressRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("POST").Path("/profiles/{id}/addresses/") + r := request.(postAddressRequest) + req.Method, req.URL.Path = "POST", url.QueryEscape("/profiles/"+r.ProfileID+"/addresses/") + return encodeRequest(ctx, req, request) +} + +func encodeDeleteAddressRequest(ctx context.Context, req *http.Request, request interface{}) error { + // r.Methods("DELETE").Path("/profiles/{id}/addresses/{addressID}") + r := request.(deleteAddressRequest) + req.Method, req.URL.Path = "DELETE", url.QueryEscape("/profiles/"+r.ProfileID+"/addresses/"+r.AddressID) + return encodeRequest(ctx, req, request) +} + +func decodePostProfileResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response postProfileResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodeGetProfileResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response getProfileResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodePutProfileResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response putProfileResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodePatchProfileResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response patchProfileResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodeDeleteProfileResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response deleteProfileResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodeGetAddressesResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response getAddressesResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodeGetAddressResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response getAddressResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodePostAddressResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response postAddressResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +func decodeDeleteAddressResponse(_ context.Context, resp *http.Response) (interface{}, error) { + var response deleteAddressResponse + err := json.NewDecoder(resp.Body).Decode(&response) + return response, err +} + +// errorer is implemented by all concrete response types that may contain +// errors. It allows us to change the HTTP response code without needing to +// trigger an endpoint (transport-level) error. For more information, read the +// big comment in endpoints.go. type errorer interface { error() error } // encodeResponse is the common method to encode all response types to the -// client. I chose to do it this way because I didn't know if something more -// specific was necessary. It's certainly possible to specialize on a -// per-response (per-method) basis. -func encodeResponse(ctx context.Context, w stdhttp.ResponseWriter, response interface{}) error { +// client. I chose to do it this way because, since we're using JSON, there's no +// reason to provide anything more specific. It's certainly possible to +// specialize on a per-response (per-method) basis. +func encodeResponse(ctx context.Context, w http.ResponseWriter, response interface{}) error { if e, ok := response.(errorer); ok && e.error() != nil { // Not a Go kit transport error, but a business-logic error. // Provide those as HTTP errors. encodeError(ctx, e.error(), w) return nil } + w.Header().Set("Content-Type", "application/json; charset=utf-8") return json.NewEncoder(w).Encode(response) } -func encodeError(_ context.Context, err error, w stdhttp.ResponseWriter) { +// encodeRequest likewise JSON-encodes the request to the HTTP request body. +// Don't use it directly as a transport/http.Client EncodeRequestFunc: +// profilesvc endpoints require mutating the HTTP method and request path. +func encodeRequest(_ context.Context, req *http.Request, request interface{}) error { + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(request) + if err != nil { + return err + } + req.Body = ioutil.NopCloser(&buf) + return nil +} + +func encodeError(_ context.Context, err error, w http.ResponseWriter) { if err == nil { panic("encodeError with nil error") } + w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(codeFrom(err)) json.NewEncoder(w).Encode(map[string]interface{}{ "error": err.Error(), @@ -250,21 +390,21 @@ func encodeError(_ context.Context, err error, w stdhttp.ResponseWriter) { func codeFrom(err error) int { switch err { - case errNotFound: - return stdhttp.StatusNotFound - case errAlreadyExists, errInconsistentIDs: - return stdhttp.StatusBadRequest + case ErrNotFound: + return http.StatusNotFound + case ErrAlreadyExists, ErrInconsistentIDs: + return http.StatusBadRequest default: - if e, ok := err.(kithttp.Error); ok { + if e, ok := err.(httptransport.Error); ok { switch e.Domain { - case kithttp.DomainDecode: - return stdhttp.StatusBadRequest - case kithttp.DomainDo: - return stdhttp.StatusServiceUnavailable + case httptransport.DomainDecode: + return http.StatusBadRequest + case httptransport.DomainDo: + return http.StatusServiceUnavailable default: - return stdhttp.StatusInternalServerError + return http.StatusInternalServerError } } - return stdhttp.StatusInternalServerError + return http.StatusInternalServerError } } diff --git a/examples/stringsvc2/instrumenting.go b/examples/stringsvc2/instrumenting.go index c3da27cac..f46184575 100644 --- a/examples/stringsvc2/instrumenting.go +++ b/examples/stringsvc2/instrumenting.go @@ -11,7 +11,7 @@ type instrumentingMiddleware struct { requestCount metrics.Counter requestLatency metrics.TimeHistogram countResult metrics.Histogram - StringService + next StringService } func (mw instrumentingMiddleware) Uppercase(s string) (output string, err error) { @@ -22,7 +22,7 @@ func (mw instrumentingMiddleware) Uppercase(s string) (output string, err error) mw.requestLatency.With(methodField).With(errorField).Observe(time.Since(begin)) }(time.Now()) - output, err = mw.StringService.Uppercase(s) + output, err = mw.next.Uppercase(s) return } @@ -35,6 +35,6 @@ func (mw instrumentingMiddleware) Count(s string) (n int) { mw.countResult.Observe(int64(n)) }(time.Now()) - n = mw.StringService.Count(s) + n = mw.next.Count(s) return } diff --git a/examples/stringsvc2/logging.go b/examples/stringsvc2/logging.go index 67fec5da0..b958f3b6f 100644 --- a/examples/stringsvc2/logging.go +++ b/examples/stringsvc2/logging.go @@ -8,7 +8,7 @@ import ( type loggingMiddleware struct { logger log.Logger - StringService + next StringService } func (mw loggingMiddleware) Uppercase(s string) (output string, err error) { @@ -22,7 +22,7 @@ func (mw loggingMiddleware) Uppercase(s string) (output string, err error) { ) }(time.Now()) - output, err = mw.StringService.Uppercase(s) + output, err = mw.next.Uppercase(s) return } @@ -36,6 +36,6 @@ func (mw loggingMiddleware) Count(s string) (n int) { ) }(time.Now()) - n = mw.StringService.Count(s) + n = mw.next.Count(s) return } diff --git a/examples/stringsvc3/proxying.go b/examples/stringsvc3/proxying.go index 78b75508c..33bc1563e 100644 --- a/examples/stringsvc3/proxying.go +++ b/examples/stringsvc3/proxying.go @@ -3,7 +3,6 @@ package main import ( "errors" "fmt" - "io" "net/url" "strings" "time" @@ -14,45 +13,70 @@ import ( "github.com/go-kit/kit/circuitbreaker" "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/static" "github.com/go-kit/kit/log" - kitratelimit "github.com/go-kit/kit/ratelimit" + "github.com/go-kit/kit/ratelimit" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/lb" httptransport "github.com/go-kit/kit/transport/http" ) -func proxyingMiddleware(proxyList string, ctx context.Context, logger log.Logger) ServiceMiddleware { - if proxyList == "" { +func proxyingMiddleware(instances string, ctx context.Context, logger log.Logger) ServiceMiddleware { + // If instances is empty, don't proxy. + if instances == "" { logger.Log("proxy_to", "none") return func(next StringService) StringService { return next } } - proxies := split(proxyList) - logger.Log("proxy_to", fmt.Sprint(proxies)) + // Set some parameters for our client. + var ( + qps = 100 // beyond which we will return an error + maxAttempts = 3 // per request, before giving up + maxTime = 250 * time.Millisecond // wallclock time, before giving up + ) + + // Otherwise, construct an endpoint for each instance in the list, and add + // it to a fixed set of endpoints. In a real service, rather than doing this + // by hand, you'd probably use package sd's support for your service + // discovery system. + var ( + instanceList = split(instances) + subscriber sd.FixedSubscriber + ) + logger.Log("proxy_to", fmt.Sprint(instanceList)) + for _, instance := range instanceList { + var e endpoint.Endpoint + e = makeUppercaseProxy(ctx, instance) + e = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{}))(e) + e = ratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(float64(qps), int64(qps)))(e) + subscriber = append(subscriber, e) + } + + // Now, build a single, retrying, load-balancing endpoint out of all of + // those individual endpoints. + balancer := lb.NewRoundRobin(subscriber) + retry := lb.Retry(maxAttempts, maxTime, balancer) + + // And finally, return the ServiceMiddleware, implemented by proxymw. return func(next StringService) StringService { - var ( - qps = 100 // max to each instance - publisher = static.NewPublisher(proxies, factory(ctx, qps), logger) - lb = loadbalancer.NewRoundRobin(publisher) - maxAttempts = 3 - maxTime = 100 * time.Millisecond - endpoint = loadbalancer.Retry(maxAttempts, maxTime, lb) - ) - return proxymw{ctx, endpoint, next} + return proxymw{ctx, next, retry} } } // proxymw implements StringService, forwarding Uppercase requests to the // provided endpoint, and serving all other (i.e. Count) requests via the -// embedded StringService. +// next StringService. type proxymw struct { - context.Context - UppercaseEndpoint endpoint.Endpoint - StringService + ctx context.Context + next StringService // Serve most requests via this service... + uppercase endpoint.Endpoint // ...except Uppercase, which gets served by this endpoint +} + +func (mw proxymw) Count(s string) int { + return mw.next.Count(s) } func (mw proxymw) Uppercase(s string) (string, error) { - response, err := mw.UppercaseEndpoint(mw.Context, uppercaseRequest{S: s}) + response, err := mw.uppercase(mw.ctx, uppercaseRequest{S: s}) if err != nil { return "", err } @@ -64,16 +88,6 @@ func (mw proxymw) Uppercase(s string) (string, error) { return resp.V, nil } -func factory(ctx context.Context, qps int) loadbalancer.Factory { - return func(instance string) (endpoint.Endpoint, io.Closer, error) { - var e endpoint.Endpoint - e = makeUppercaseProxy(ctx, instance) - e = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{}))(e) - e = kitratelimit.NewTokenBucketLimiter(jujuratelimit.NewBucketWithRate(float64(qps), int64(qps)))(e) - return e, nil, nil - } -} - func makeUppercaseProxy(ctx context.Context, instance string) endpoint.Endpoint { if !strings.HasPrefix(instance, "http") { instance = "http://" + instance diff --git a/loadbalancer/README.md b/loadbalancer/README.md deleted file mode 100644 index ac8a88ed4..000000000 --- a/loadbalancer/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# package loadbalancer - -`package loadbalancer` provides a client-side load balancer abstraction. - -A publisher is responsible for emitting the most recent set of endpoints for a -single logical service. Publishers exist for static endpoints, and endpoints -discovered via periodic DNS SRV lookups on a single logical name. Consul and -etcd publishers are planned. - -Different load balancers are implemented on top of publishers. Go kit -currently provides random and round-robin load balancers. Smarter behaviors, -e.g. load balancing based on underlying endpoint priority/weight, is planned. - -## Rationale - -TODO - -## Usage - -In your client, construct a publisher for a specific remote service, and pass -it to a load balancer. Then, request an endpoint from the load balancer -whenever you need to make a request to that remote service. - -```go -import ( - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/dnssrv" -) - -func main() { - // Construct a load balancer for foosvc, which gets foosvc instances by - // polling a specific DNS SRV name. - p, err := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) - if err != nil { - panic(err) - } - - lb := loadbalancer.NewRoundRobin(p) - - // Get a new endpoint from the load balancer. - endpoint, err := lb.Endpoint() - if err != nil { - panic(err) - } - - // Use the endpoint to make a request. - response, err := endpoint(ctx, request) -} - -func fooFactory(instance string) (endpoint.Endpoint, error) { - // Convert an instance (host:port) to an endpoint, via a defined transport binding. -} -``` - -It's also possible to wrap a load balancer with a retry strategy, so that it -can be used as an endpoint directly. This may make load balancers more -convenient to use, at the cost of fine-grained control of failures. - -```go -func main() { - p := dnssrv.NewPublisher("foosvc.internal.domain", 5*time.Second, fooFactory, logger) - lb := loadbalancer.NewRoundRobin(p) - endpoint := loadbalancer.Retry(3, 5*time.Seconds, lb) - - response, err := endpoint(ctx, request) // requests will be automatically load balanced -} -``` diff --git a/loadbalancer/consul/client.go b/loadbalancer/consul/client.go deleted file mode 100644 index 7f5af3e8a..000000000 --- a/loadbalancer/consul/client.go +++ /dev/null @@ -1,30 +0,0 @@ -package consul - -import consul "github.com/hashicorp/consul/api" - -// Client is a wrapper around the Consul API. -type Client interface { - Service(service string, tag string, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) -} - -type client struct { - consul *consul.Client -} - -// NewClient returns an implementation of the Client interface expecting a fully -// setup Consul Client. -func NewClient(c *consul.Client) Client { - return &client{ - consul: c, - } -} - -// GetInstances returns the list of healthy entries for a given service filtered -// by tag. -func (c *client) Service( - service string, - tag string, - opts *consul.QueryOptions, -) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { - return c.consul.Health().Service(service, tag, true, opts) -} diff --git a/loadbalancer/consul/publisher.go b/loadbalancer/consul/publisher.go deleted file mode 100644 index eb64904d6..000000000 --- a/loadbalancer/consul/publisher.go +++ /dev/null @@ -1,174 +0,0 @@ -package consul - -import ( - "fmt" - "strings" - - consul "github.com/hashicorp/consul/api" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" -) - -const defaultIndex = 0 - -// Publisher yields endpoints for a service in Consul. Updates to the service -// are watched and will update the Publisher endpoints. -type Publisher struct { - cache *loadbalancer.EndpointCache - client Client - logger log.Logger - service string - tags []string - endpointsc chan []endpoint.Endpoint - quitc chan struct{} -} - -// NewPublisher returns a Consul publisher which returns Endpoints for the -// requested service. It only returns instances for which all of the passed -// tags are present. -func NewPublisher( - client Client, - factory loadbalancer.Factory, - logger log.Logger, - service string, - tags ...string, -) (*Publisher, error) { - p := &Publisher{ - cache: loadbalancer.NewEndpointCache(factory, logger), - client: client, - logger: logger, - service: service, - tags: tags, - quitc: make(chan struct{}), - } - - instances, index, err := p.getInstances(defaultIndex) - if err == nil { - logger.Log("service", service, "tags", strings.Join(tags, ", "), "instances", len(instances)) - } else { - logger.Log("service", service, "tags", strings.Join(tags, ", "), "err", err) - } - p.cache.Replace(instances) - - go p.loop(index) - - return p, nil -} - -// Endpoints implements the Publisher interface. -func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { - return p.cache.Endpoints() -} - -// Stop terminates the publisher. -func (p *Publisher) Stop() { - close(p.quitc) -} - -func (p *Publisher) loop(lastIndex uint64) { - var ( - errc = make(chan error, 1) - resc = make(chan response, 1) - ) - - for { - go func() { - instances, index, err := p.getInstances(lastIndex) - if err != nil { - errc <- err - return - } - resc <- response{ - index: index, - instances: instances, - } - }() - - select { - case err := <-errc: - p.logger.Log("service", p.service, "err", err) - case res := <-resc: - p.cache.Replace(res.instances) - lastIndex = res.index - case <-p.quitc: - return - } - } -} - -func (p *Publisher) getInstances(lastIndex uint64) ([]string, uint64, error) { - tag := "" - - if len(p.tags) > 0 { - tag = p.tags[0] - } - - entries, meta, err := p.client.Service( - p.service, - tag, - &consul.QueryOptions{ - WaitIndex: lastIndex, - }, - ) - if err != nil { - return nil, 0, err - } - - // If more than one tag is passed we need to filter it in the publisher until - // Consul supports multiple tags[0]. - // - // [0] https://github.com/hashicorp/consul/issues/294 - if len(p.tags) > 1 { - entries = filterEntries(entries, p.tags[1:]...) - } - - return makeInstances(entries), meta.LastIndex, nil -} - -// response is used as container to transport instances as well as the updated -// index. -type response struct { - index uint64 - instances []string -} - -func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry { - var es []*consul.ServiceEntry - -ENTRIES: - for _, entry := range entries { - ts := make(map[string]struct{}, len(entry.Service.Tags)) - - for _, tag := range entry.Service.Tags { - ts[tag] = struct{}{} - } - - for _, tag := range tags { - if _, ok := ts[tag]; !ok { - continue ENTRIES - } - } - - es = append(es, entry) - } - - return es -} - -func makeInstances(entries []*consul.ServiceEntry) []string { - instances := make([]string, len(entries)) - - for i, entry := range entries { - addr := entry.Node.Address - - if entry.Service.Address != "" { - addr = entry.Service.Address - } - - instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port) - } - - return instances -} diff --git a/loadbalancer/consul/publisher_test.go b/loadbalancer/consul/publisher_test.go deleted file mode 100644 index 23f4f1b6f..000000000 --- a/loadbalancer/consul/publisher_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package consul - -import ( - "io" - "testing" - - consul "github.com/hashicorp/consul/api" - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/log" -) - -var consulState = []*consul.ServiceEntry{ - { - Node: &consul.Node{ - Address: "10.0.0.0", - Node: "app00.local", - }, - Service: &consul.AgentService{ - ID: "search-api-0", - Port: 8000, - Service: "search", - Tags: []string{ - "api", - "v1", - }, - }, - }, - { - Node: &consul.Node{ - Address: "10.0.0.1", - Node: "app01.local", - }, - Service: &consul.AgentService{ - ID: "search-api-1", - Port: 8001, - Service: "search", - Tags: []string{ - "api", - "v2", - }, - }, - }, - { - Node: &consul.Node{ - Address: "10.0.0.1", - Node: "app01.local", - }, - Service: &consul.AgentService{ - Address: "10.0.0.10", - ID: "search-db-0", - Port: 9000, - Service: "search", - Tags: []string{ - "db", - }, - }, - }, -} - -func TestPublisher(t *testing.T) { - var ( - logger = log.NewNopLogger() - client = newTestClient(consulState) - ) - - p, err := NewPublisher(client, testFactory, logger, "search", "api") - if err != nil { - t.Fatalf("publisher setup failed: %s", err) - } - defer p.Stop() - - eps, err := p.Endpoints() - if err != nil { - t.Fatalf("endpoints failed: %s", err) - } - - if have, want := len(eps), 2; have != want { - t.Errorf("have %v, want %v", have, want) - } -} - -func TestPublisherNoService(t *testing.T) { - var ( - logger = log.NewNopLogger() - client = newTestClient(consulState) - ) - - p, err := NewPublisher(client, testFactory, logger, "feed") - if err != nil { - t.Fatalf("publisher setup failed: %s", err) - } - defer p.Stop() - - eps, err := p.Endpoints() - if err != nil { - t.Fatalf("endpoints failed: %s", err) - } - - if have, want := len(eps), 0; have != want { - t.Fatalf("have %v, want %v", have, want) - } -} - -func TestPublisherWithTags(t *testing.T) { - var ( - logger = log.NewNopLogger() - client = newTestClient(consulState) - ) - - p, err := NewPublisher(client, testFactory, logger, "search", "api", "v2") - if err != nil { - t.Fatalf("publisher setup failed: %s", err) - } - defer p.Stop() - - eps, err := p.Endpoints() - if err != nil { - t.Fatalf("endpoints failed: %s", err) - } - - if have, want := len(eps), 1; have != want { - t.Fatalf("have %v, want %v", have, want) - } -} - -func TestPublisherAddressOverride(t *testing.T) { - var ( - ctx = context.Background() - logger = log.NewNopLogger() - client = newTestClient(consulState) - ) - - p, err := NewPublisher(client, testFactory, logger, "search", "db") - if err != nil { - t.Fatalf("publisher setup failed: %s", err) - } - defer p.Stop() - - eps, err := p.Endpoints() - if err != nil { - t.Fatalf("endpoints failed: %s", err) - } - - if have, want := len(eps), 1; have != want { - t.Fatalf("have %v, want %v", have, want) - } - - ins, err := eps[0](ctx, struct{}{}) - if err != nil { - t.Fatal(err) - } - - if have, want := ins.(string), "10.0.0.10:9000"; have != want { - t.Errorf("have %#v, want %#v", have, want) - } -} - -type testClient struct { - entries []*consul.ServiceEntry -} - -func newTestClient(entries []*consul.ServiceEntry) Client { - if entries == nil { - entries = []*consul.ServiceEntry{} - } - - return &testClient{ - entries: entries, - } -} - -func (c *testClient) Service( - service string, - tag string, - opts *consul.QueryOptions, -) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { - es := []*consul.ServiceEntry{} - - for _, e := range c.entries { - if e.Service.Service != service { - continue - } - if tag != "" { - tagMap := map[string]struct{}{} - - for _, t := range e.Service.Tags { - tagMap[t] = struct{}{} - } - - if _, ok := tagMap[tag]; !ok { - continue - } - } - - es = append(es, e) - } - - return es, &consul.QueryMeta{}, nil -} - -func testFactory(ins string) (endpoint.Endpoint, io.Closer, error) { - return func(context.Context, interface{}) (interface{}, error) { - return ins, nil - }, nil, nil -} diff --git a/loadbalancer/dnssrv/publisher.go b/loadbalancer/dnssrv/publisher.go deleted file mode 100644 index 62d6b4d21..000000000 --- a/loadbalancer/dnssrv/publisher.go +++ /dev/null @@ -1,106 +0,0 @@ -package dnssrv - -import ( - "fmt" - "net" - "time" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" -) - -// Publisher yields endpoints taken from the named DNS SRV record. The name is -// resolved on a fixed schedule. Priorities and weights are ignored. -type Publisher struct { - name string - cache *loadbalancer.EndpointCache - logger log.Logger - quit chan struct{} -} - -// NewPublisher returns a DNS SRV publisher. The name is resolved -// synchronously as part of construction; if that resolution fails, the -// constructor will return an error. The factory is used to convert a -// host:port to a usable endpoint. The logger is used to report DNS and -// factory errors. -func NewPublisher( - name string, - ttl time.Duration, - factory loadbalancer.Factory, - logger log.Logger, -) *Publisher { - return NewPublisherDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger) -} - -// NewPublisherDetailed is the same as NewPublisher, but allows users to provide -// an explicit lookup refresh ticker instead of a TTL, and specify the function -// used to perform lookups instead of using net.LookupSRV. -func NewPublisherDetailed( - name string, - refreshTicker *time.Ticker, - lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error), - factory loadbalancer.Factory, - logger log.Logger, -) *Publisher { - p := &Publisher{ - name: name, - cache: loadbalancer.NewEndpointCache(factory, logger), - logger: logger, - quit: make(chan struct{}), - } - - instances, err := p.resolve(lookupSRV) - if err == nil { - logger.Log("name", name, "instances", len(instances)) - } else { - logger.Log("name", name, "err", err) - } - p.cache.Replace(instances) - - go p.loop(refreshTicker, lookupSRV) - return p -} - -// Stop terminates the publisher. -func (p *Publisher) Stop() { - close(p.quit) -} - -func (p *Publisher) loop( - refreshTicker *time.Ticker, - lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error), -) { - defer refreshTicker.Stop() - for { - select { - case <-refreshTicker.C: - instances, err := p.resolve(lookupSRV) - if err != nil { - p.logger.Log(p.name, err) - continue // don't replace potentially-good with bad - } - p.cache.Replace(instances) - - case <-p.quit: - return - } - } -} - -// Endpoints implements the Publisher interface. -func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { - return p.cache.Endpoints() -} - -func (p *Publisher) resolve(lookupSRV func(service, proto, name string) (cname string, addrs []*net.SRV, err error)) ([]string, error) { - _, addrs, err := lookupSRV("", "", p.name) - if err != nil { - return []string{}, err - } - instances := make([]string, len(addrs)) - for i, addr := range addrs { - instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) - } - return instances, nil -} diff --git a/loadbalancer/dnssrv/publisher_test.go b/loadbalancer/dnssrv/publisher_test.go deleted file mode 100644 index 363b7d2be..000000000 --- a/loadbalancer/dnssrv/publisher_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package dnssrv - -import ( - "errors" - "io" - "net" - "sync/atomic" - "testing" - "time" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/log" -) - -func TestPublisher(t *testing.T) { - var ( - name = "foo" - ttl = time.Second - e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } - logger = log.NewNopLogger() - ) - - p := NewPublisher(name, ttl, factory, logger) - defer p.Stop() - - if _, err := p.Endpoints(); err != nil { - t.Fatal(err) - } -} - -func TestBadLookup(t *testing.T) { - var ( - name = "some-name" - ticker = time.NewTicker(time.Second) - lookups = uint32(0) - lookupSRV = func(string, string, string) (string, []*net.SRV, error) { - atomic.AddUint32(&lookups, 1) - return "", nil, errors.New("kaboom") - } - e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } - logger = log.NewNopLogger() - ) - - p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) - defer p.Stop() - - endpoints, err := p.Endpoints() - if err != nil { - t.Error(err) - } - if want, have := 0, len(endpoints); want != have { - t.Errorf("want %d, have %d", want, have) - } - if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have { - t.Errorf("want %d, have %d", want, have) - } -} - -func TestBadFactory(t *testing.T) { - var ( - name = "some-name" - ticker = time.NewTicker(time.Second) - addr = &net.SRV{Target: "foo", Port: 1234} - addrs = []*net.SRV{addr} - lookupSRV = func(a, b, c string) (string, []*net.SRV, error) { return "", addrs, nil } - creates = uint32(0) - factory = func(s string) (endpoint.Endpoint, io.Closer, error) { - atomic.AddUint32(&creates, 1) - return nil, nil, errors.New("kaboom") - } - logger = log.NewNopLogger() - ) - - p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) - defer p.Stop() - - endpoints, err := p.Endpoints() - if err != nil { - t.Error(err) - } - if want, have := 0, len(endpoints); want != have { - t.Errorf("want %q, have %q", want, have) - } - if want, have := uint32(1), atomic.LoadUint32(&creates); want != have { - t.Errorf("want %d, have %d", want, have) - } -} - -func TestRefreshWithChange(t *testing.T) { - t.Skip("TODO") -} - -func TestRefreshNoChange(t *testing.T) { - var ( - addr = &net.SRV{Target: "my-target", Port: 5678} - addrs = []*net.SRV{addr} - name = "my-name" - ticker = time.NewTicker(time.Second) - lookups = uint32(0) - lookupSRV = func(string, string, string) (string, []*net.SRV, error) { - atomic.AddUint32(&lookups, 1) - return "", addrs, nil - } - e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - factory = func(string) (endpoint.Endpoint, io.Closer, error) { return e, nil, nil } - logger = log.NewNopLogger() - ) - - ticker.Stop() - tickc := make(chan time.Time) - ticker.C = tickc - - p := NewPublisherDetailed(name, ticker, lookupSRV, factory, logger) - defer p.Stop() - - if want, have := uint32(1), atomic.LoadUint32(&lookups); want != have { - t.Errorf("want %d, have %d", want, have) - } - - tickc <- time.Now() - - if want, have := uint32(2), atomic.LoadUint32(&lookups); want != have { - t.Errorf("want %d, have %d", want, have) - } -} - -func TestRefreshResolveError(t *testing.T) { - t.Skip("TODO") -} diff --git a/loadbalancer/endpoint_cache.go b/loadbalancer/endpoint_cache.go deleted file mode 100644 index df2b781d5..000000000 --- a/loadbalancer/endpoint_cache.go +++ /dev/null @@ -1,112 +0,0 @@ -package loadbalancer - -import ( - "io" - "sort" - "sync" - "sync/atomic" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/log" -) - -// EndpointCache caches endpoints that need to be deallocated when they're no -// longer useful. Clients update the cache by providing a current set of -// instance strings. The cache converts each instance string to an endpoint -// and a closer via the factory function. -// -// Instance strings are assumed to be unique and are used as keys. Endpoints -// that were in the previous set of instances and are not in the current set -// are considered invalid and closed. -// -// EndpointCache is designed to be used in your publisher implementation. -type EndpointCache struct { - mtx sync.Mutex - f Factory - m map[string]endpointCloser - cache atomic.Value //[]endpoint.Endpoint - logger log.Logger -} - -// NewEndpointCache produces a new EndpointCache, ready for use. Instance -// strings will be converted to endpoints via the provided factory function. -// The logger is used to log errors. -func NewEndpointCache(f Factory, logger log.Logger) *EndpointCache { - endpointCache := &EndpointCache{ - f: f, - m: map[string]endpointCloser{}, - logger: log.NewContext(logger).With("component", "Endpoint Cache"), - } - - endpointCache.cache.Store(make([]endpoint.Endpoint, 0)) - - return endpointCache -} - -type endpointCloser struct { - endpoint.Endpoint - io.Closer -} - -// Replace replaces the current set of endpoints with endpoints manufactured -// by the passed instances. If the same instance exists in both the existing -// and new sets, it's left untouched. -func (t *EndpointCache) Replace(instances []string) { - t.mtx.Lock() - defer t.mtx.Unlock() - - // Produce the current set of endpoints. - oldMap := t.m - t.m = make(map[string]endpointCloser, len(instances)) - for _, instance := range instances { - // If it already exists, just copy it over. - if ec, ok := oldMap[instance]; ok { - t.m[instance] = ec - delete(oldMap, instance) - continue - } - - // If it doesn't exist, create it. - endpoint, closer, err := t.f(instance) - if err != nil { - t.logger.Log("instance", instance, "err", err) - continue - } - t.m[instance] = endpointCloser{endpoint, closer} - } - - t.refreshCache() - - // Close any leftover endpoints. - for _, ec := range oldMap { - if ec.Closer != nil { - ec.Closer.Close() - } - } -} - -func (t *EndpointCache) refreshCache() { - var ( - length = len(t.m) - instances = make([]string, 0, length) - newCache = make([]endpoint.Endpoint, 0, length) - ) - - for instance, _ := range t.m { - instances = append(instances, instance) - } - // Sort the instances for ensuring that Endpoints are returned into the same order if no modified. - sort.Strings(instances) - - for _, instance := range instances { - newCache = append(newCache, t.m[instance].Endpoint) - } - - t.cache.Store(newCache) -} - -// Endpoints returns the current set of endpoints in undefined order. Satisfies -// Publisher interface. -func (t *EndpointCache) Endpoints() ([]endpoint.Endpoint, error) { - return t.cache.Load().([]endpoint.Endpoint), nil -} diff --git a/loadbalancer/endpoint_cache_test.go b/loadbalancer/endpoint_cache_test.go deleted file mode 100644 index 423a1f040..000000000 --- a/loadbalancer/endpoint_cache_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package loadbalancer_test - -import ( - "io" - "testing" - "time" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" -) - -func TestEndpointCache(t *testing.T) { - var ( - e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - ca = make(closer) - cb = make(closer) - c = map[string]io.Closer{"a": ca, "b": cb} - f = func(s string) (endpoint.Endpoint, io.Closer, error) { return e, c[s], nil } - ec = loadbalancer.NewEndpointCache(f, log.NewNopLogger()) - ) - - // Populate - ec.Replace([]string{"a", "b"}) - select { - case <-ca: - t.Errorf("endpoint a closed, not good") - case <-cb: - t.Errorf("endpoint b closed, not good") - case <-time.After(time.Millisecond): - t.Logf("no closures yet, good") - } - - // Duplicate, should be no-op - ec.Replace([]string{"a", "b"}) - select { - case <-ca: - t.Errorf("endpoint a closed, not good") - case <-cb: - t.Errorf("endpoint b closed, not good") - case <-time.After(time.Millisecond): - t.Logf("no closures yet, good") - } - - // Delete b - go ec.Replace([]string{"a"}) - select { - case <-ca: - t.Errorf("endpoint a closed, not good") - case <-cb: - t.Logf("endpoint b closed, good") - case <-time.After(time.Millisecond): - t.Errorf("didn't close the deleted instance in time") - } - - // Delete a - go ec.Replace([]string{""}) - select { - // case <-cb: will succeed, as it's closed - case <-ca: - t.Logf("endpoint a closed, good") - case <-time.After(time.Millisecond): - t.Errorf("didn't close the deleted instance in time") - } -} - -type closer chan struct{} - -func (c closer) Close() error { close(c); return nil } - -func BenchmarkEndpoints(b *testing.B) { - var ( - e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - ca = make(closer) - cb = make(closer) - c = map[string]io.Closer{"a": ca, "b": cb} - f = func(s string) (endpoint.Endpoint, io.Closer, error) { return e, c[s], nil } - ec = loadbalancer.NewEndpointCache(f, log.NewNopLogger()) - ) - - b.ReportAllocs() - - ec.Replace([]string{"a", "b"}) - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - ec.Endpoints() - } - }) -} \ No newline at end of file diff --git a/loadbalancer/etcd/publisher.go b/loadbalancer/etcd/publisher.go deleted file mode 100644 index 37c8aac47..000000000 --- a/loadbalancer/etcd/publisher.go +++ /dev/null @@ -1,71 +0,0 @@ -package etcd - -import ( - etcd "github.com/coreos/etcd/client" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" -) - -// Publisher yield endpoints stored in a certain etcd keyspace. Any kind of -// change in that keyspace is watched and will update the Publisher endpoints. -type Publisher struct { - client Client - prefix string - cache *loadbalancer.EndpointCache - logger log.Logger - quit chan struct{} -} - -// NewPublisher returs a etcd publisher. Etcd will start watching the given -// prefix for changes and update the Publisher endpoints. -func NewPublisher(c Client, prefix string, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { - p := &Publisher{ - client: c, - prefix: prefix, - cache: loadbalancer.NewEndpointCache(f, logger), - logger: logger, - quit: make(chan struct{}), - } - - instances, err := p.client.GetEntries(p.prefix) - if err == nil { - logger.Log("prefix", p.prefix, "instances", len(instances)) - } else { - logger.Log("prefix", p.prefix, "err", err) - } - p.cache.Replace(instances) - - go p.loop() - return p, nil -} - -func (p *Publisher) loop() { - responseChan := make(chan *etcd.Response) - go p.client.WatchPrefix(p.prefix, responseChan) - for { - select { - case <-responseChan: - instances, err := p.client.GetEntries(p.prefix) - if err != nil { - p.logger.Log("msg", "failed to retrieve entries", "err", err) - continue - } - p.cache.Replace(instances) - - case <-p.quit: - return - } - } -} - -// Endpoints implements the Publisher interface. -func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { - return p.cache.Endpoints() -} - -// Stop terminates the Publisher. -func (p *Publisher) Stop() { - close(p.quit) -} diff --git a/loadbalancer/factory.go b/loadbalancer/factory.go deleted file mode 100644 index 71a7be5d6..000000000 --- a/loadbalancer/factory.go +++ /dev/null @@ -1,15 +0,0 @@ -package loadbalancer - -import ( - "io" - - "github.com/go-kit/kit/endpoint" -) - -// Factory is a function that converts an instance string, e.g. a host:port, -// to a usable endpoint. Factories are used by load balancers to convert -// instances returned by Publishers (typically host:port strings) into -// endpoints. Users are expected to provide their own factory functions that -// assume specific transports, or can deduce transports by parsing the -// instance string. -type Factory func(instance string) (endpoint.Endpoint, io.Closer, error) diff --git a/loadbalancer/fixed/publisher.go b/loadbalancer/fixed/publisher.go deleted file mode 100644 index a4be875c2..000000000 --- a/loadbalancer/fixed/publisher.go +++ /dev/null @@ -1,35 +0,0 @@ -package fixed - -import ( - "sync" - - "github.com/go-kit/kit/endpoint" -) - -// Publisher yields the same set of fixed endpoints. -type Publisher struct { - mtx sync.RWMutex - endpoints []endpoint.Endpoint -} - -// NewPublisher returns a fixed endpoint Publisher. -func NewPublisher(endpoints []endpoint.Endpoint) *Publisher { - return &Publisher{ - endpoints: endpoints, - } -} - -// Endpoints implements the Publisher interface. -func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { - p.mtx.RLock() - defer p.mtx.RUnlock() - return p.endpoints, nil -} - -// Replace is a utility method to swap out the underlying endpoints of an -// existing fixed publisher. It's useful mostly for testing. -func (p *Publisher) Replace(endpoints []endpoint.Endpoint) { - p.mtx.Lock() - defer p.mtx.Unlock() - p.endpoints = endpoints -} diff --git a/loadbalancer/fixed/publisher_test.go b/loadbalancer/fixed/publisher_test.go deleted file mode 100644 index 90c1b40a7..000000000 --- a/loadbalancer/fixed/publisher_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package fixed_test - -import ( - "reflect" - "testing" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer/fixed" -) - -func TestFixed(t *testing.T) { - var ( - e1 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - e2 = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - endpoints = []endpoint.Endpoint{e1, e2} - ) - p := fixed.NewPublisher(endpoints) - have, err := p.Endpoints() - if err != nil { - t.Fatal(err) - } - if want := endpoints; !reflect.DeepEqual(want, have) { - t.Fatalf("want %#+v, have %#+v", want, have) - } -} - -func TestFixedReplace(t *testing.T) { - p := fixed.NewPublisher([]endpoint.Endpoint{ - func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - }) - have, err := p.Endpoints() - if err != nil { - t.Fatal(err) - } - if want, have := 1, len(have); want != have { - t.Fatalf("want %d, have %d", want, have) - } - p.Replace([]endpoint.Endpoint{}) - have, err = p.Endpoints() - if err != nil { - t.Fatal(err) - } - if want, have := 0, len(have); want != have { - t.Fatalf("want %d, have %d", want, have) - } -} diff --git a/loadbalancer/loadbalancer.go b/loadbalancer/loadbalancer.go deleted file mode 100644 index 6a99e66e8..000000000 --- a/loadbalancer/loadbalancer.go +++ /dev/null @@ -1,18 +0,0 @@ -package loadbalancer - -import ( - "errors" - - "github.com/go-kit/kit/endpoint" -) - -// LoadBalancer describes something that can yield endpoints for a remote -// service method. -type LoadBalancer interface { - Endpoint() (endpoint.Endpoint, error) -} - -// ErrNoEndpoints is returned when a load balancer (or one of its components) -// has no endpoints to return. In a request lifecycle, this is usually a fatal -// error. -var ErrNoEndpoints = errors.New("no endpoints available") diff --git a/loadbalancer/publisher.go b/loadbalancer/publisher.go deleted file mode 100644 index ec17d7e4c..000000000 --- a/loadbalancer/publisher.go +++ /dev/null @@ -1,10 +0,0 @@ -package loadbalancer - -import "github.com/go-kit/kit/endpoint" - -// Publisher describes something that provides a set of identical endpoints. -// Different publisher implementations exist for different kinds of service -// discovery systems. -type Publisher interface { - Endpoints() ([]endpoint.Endpoint, error) -} diff --git a/loadbalancer/random.go b/loadbalancer/random.go deleted file mode 100644 index dcab16531..000000000 --- a/loadbalancer/random.go +++ /dev/null @@ -1,34 +0,0 @@ -package loadbalancer - -import ( - "math/rand" - - "github.com/go-kit/kit/endpoint" -) - -// Random is a completely stateless load balancer that chooses a random -// endpoint to return each time. -type Random struct { - p Publisher - r *rand.Rand -} - -// NewRandom returns a new Random load balancer. -func NewRandom(p Publisher, seed int64) *Random { - return &Random{ - p: p, - r: rand.New(rand.NewSource(seed)), - } -} - -// Endpoint implements the LoadBalancer interface. -func (r *Random) Endpoint() (endpoint.Endpoint, error) { - endpoints, err := r.p.Endpoints() - if err != nil { - return nil, err - } - if len(endpoints) <= 0 { - return nil, ErrNoEndpoints - } - return endpoints[r.r.Intn(len(endpoints))], nil -} diff --git a/loadbalancer/random_test.go b/loadbalancer/random_test.go deleted file mode 100644 index e6c5f831c..000000000 --- a/loadbalancer/random_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package loadbalancer_test - -import ( - "math" - "testing" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/fixed" -) - -func TestRandomDistribution(t *testing.T) { - var ( - n = 3 - endpoints = make([]endpoint.Endpoint, n) - counts = make([]int, n) - seed = int64(123) - ctx = context.Background() - iterations = 100000 - want = iterations / n - tolerance = want / 100 // 1% - ) - - for i := 0; i < n; i++ { - i0 := i - endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } - } - - lb := loadbalancer.NewRandom(fixed.NewPublisher(endpoints), seed) - - for i := 0; i < iterations; i++ { - e, err := lb.Endpoint() - if err != nil { - t.Fatal(err) - } - if _, err := e(ctx, struct{}{}); err != nil { - t.Error(err) - } - } - - for i, have := range counts { - if math.Abs(float64(want-have)) > float64(tolerance) { - t.Errorf("%d: want %d, have %d", i, want, have) - } - } -} - -func TestRandomBadPublisher(t *testing.T) { - t.Skip("TODO") -} - -func TestRandomNoEndpoints(t *testing.T) { - lb := loadbalancer.NewRandom(fixed.NewPublisher([]endpoint.Endpoint{}), 123) - _, have := lb.Endpoint() - if want := loadbalancer.ErrNoEndpoints; want != have { - t.Errorf("want %q, have %q", want, have) - } -} diff --git a/loadbalancer/round_robin.go b/loadbalancer/round_robin.go deleted file mode 100644 index fe6d29d2d..000000000 --- a/loadbalancer/round_robin.go +++ /dev/null @@ -1,41 +0,0 @@ -package loadbalancer - -import ( - "sync/atomic" - - "github.com/go-kit/kit/endpoint" -) - -// RoundRobin is a simple load balancer that returns each of the published -// endpoints in sequence. -type RoundRobin struct { - p Publisher - counter uint64 -} - -// NewRoundRobin returns a new RoundRobin load balancer. -func NewRoundRobin(p Publisher) *RoundRobin { - return &RoundRobin{ - p: p, - counter: 0, - } -} - -// Endpoint implements the LoadBalancer interface. -func (rr *RoundRobin) Endpoint() (endpoint.Endpoint, error) { - endpoints, err := rr.p.Endpoints() - if err != nil { - return nil, err - } - if len(endpoints) <= 0 { - return nil, ErrNoEndpoints - } - var old uint64 - for { - old = atomic.LoadUint64(&rr.counter) - if atomic.CompareAndSwapUint64(&rr.counter, old, old+1) { - break - } - } - return endpoints[old%uint64(len(endpoints))], nil -} diff --git a/loadbalancer/round_robin_test.go b/loadbalancer/round_robin_test.go deleted file mode 100644 index beabe80d3..000000000 --- a/loadbalancer/round_robin_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package loadbalancer_test - -import ( - "reflect" - "testing" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/fixed" - "golang.org/x/net/context" -) - -func TestRoundRobinDistribution(t *testing.T) { - var ( - ctx = context.Background() - counts = []int{0, 0, 0} - endpoints = []endpoint.Endpoint{ - func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, - func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, - } - ) - - lb := loadbalancer.NewRoundRobin(fixed.NewPublisher(endpoints)) - - for i, want := range [][]int{ - {1, 0, 0}, - {1, 1, 0}, - {1, 1, 1}, - {2, 1, 1}, - {2, 2, 1}, - {2, 2, 2}, - {3, 2, 2}, - } { - e, err := lb.Endpoint() - if err != nil { - t.Fatal(err) - } - if _, err := e(ctx, struct{}{}); err != nil { - t.Error(err) - } - if have := counts; !reflect.DeepEqual(want, have) { - t.Fatalf("%d: want %v, have %v", i, want, have) - } - - } -} - -func TestRoundRobinBadPublisher(t *testing.T) { - t.Skip("TODO") -} diff --git a/loadbalancer/static/publisher.go b/loadbalancer/static/publisher.go deleted file mode 100644 index e5d552a6a..000000000 --- a/loadbalancer/static/publisher.go +++ /dev/null @@ -1,31 +0,0 @@ -package static - -import ( - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/fixed" - "github.com/go-kit/kit/log" -) - -// Publisher yields a set of static endpoints as produced by the passed factory. -type Publisher struct{ publisher *fixed.Publisher } - -// NewPublisher returns a static endpoint Publisher. -func NewPublisher(instances []string, factory loadbalancer.Factory, logger log.Logger) Publisher { - logger = log.NewContext(logger).With("component", "Static Publisher") - endpoints := []endpoint.Endpoint{} - for _, instance := range instances { - e, _, err := factory(instance) // never close - if err != nil { - logger.Log("instance", instance, "err", err) - continue - } - endpoints = append(endpoints, e) - } - return Publisher{publisher: fixed.NewPublisher(endpoints)} -} - -// Endpoints implements Publisher. -func (p Publisher) Endpoints() ([]endpoint.Endpoint, error) { - return p.publisher.Endpoints() -} diff --git a/loadbalancer/static/publisher_test.go b/loadbalancer/static/publisher_test.go deleted file mode 100644 index 42cd58190..000000000 --- a/loadbalancer/static/publisher_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package static_test - -import ( - "fmt" - "io" - "testing" - - "golang.org/x/net/context" - - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer/static" - "github.com/go-kit/kit/log" -) - -func TestStatic(t *testing.T) { - var ( - instances = []string{"foo", "bar", "baz"} - endpoints = map[string]endpoint.Endpoint{ - "foo": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - "bar": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - "baz": func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - } - factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { - if e, ok := endpoints[instance]; ok { - return e, nil, nil - } - return nil, nil, fmt.Errorf("%s: not found", instance) - } - ) - p := static.NewPublisher(instances, factory, log.NewNopLogger()) - have, err := p.Endpoints() - if err != nil { - t.Fatal(err) - } - want := []endpoint.Endpoint{endpoints["foo"], endpoints["bar"], endpoints["baz"]} - if fmt.Sprint(want) != fmt.Sprint(have) { - t.Fatalf("want %v, have %v", want, have) - } -} diff --git a/loadbalancer/zk/publisher.go b/loadbalancer/zk/publisher.go deleted file mode 100644 index da528327e..000000000 --- a/loadbalancer/zk/publisher.go +++ /dev/null @@ -1,83 +0,0 @@ -package zk - -import ( - "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/log" - "github.com/samuel/go-zookeeper/zk" -) - -// Publisher yield endpoints stored in a certain ZooKeeper path. Any kind of -// change in that path is watched and will update the Publisher endpoints. -type Publisher struct { - client Client - path string - cache *loadbalancer.EndpointCache - logger log.Logger - quit chan struct{} -} - -// NewPublisher returns a ZooKeeper publisher. ZooKeeper will start watching the -// given path for changes and update the Publisher endpoints. -func NewPublisher(c Client, path string, f loadbalancer.Factory, logger log.Logger) (*Publisher, error) { - p := &Publisher{ - client: c, - path: path, - cache: loadbalancer.NewEndpointCache(f, logger), - logger: logger, - quit: make(chan struct{}), - } - - err := p.client.CreateParentNodes(p.path) - if err != nil { - return nil, err - } - - // initial node retrieval and cache fill - instances, eventc, err := p.client.GetEntries(p.path) - if err != nil { - logger.Log("path", p.path, "msg", "failed to retrieve entries", "err", err) - return nil, err - } - logger.Log("path", p.path, "instances", len(instances)) - p.cache.Replace(instances) - - // handle incoming path updates - go p.loop(eventc) - - return p, nil -} - -func (p *Publisher) loop(eventc <-chan zk.Event) { - var ( - instances []string - err error - ) - for { - select { - case <-eventc: - // we received a path update notification, call GetEntries to - // retrieve child node data and set new watch as zk watches are one - // time triggers - instances, eventc, err = p.client.GetEntries(p.path) - if err != nil { - p.logger.Log("path", p.path, "msg", "failed to retrieve entries", "err", err) - continue - } - p.logger.Log("path", p.path, "instances", len(instances)) - p.cache.Replace(instances) - case <-p.quit: - return - } - } -} - -// Endpoints implements the Publisher interface. -func (p *Publisher) Endpoints() ([]endpoint.Endpoint, error) { - return p.cache.Endpoints() -} - -// Stop terminates the Publisher. -func (p *Publisher) Stop() { - close(p.quit) -} diff --git a/loadbalancer/zk/publisher_test.go b/loadbalancer/zk/publisher_test.go deleted file mode 100644 index da3d619d4..000000000 --- a/loadbalancer/zk/publisher_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package zk - -import ( - "testing" - "time" -) - -func TestPublisher(t *testing.T) { - client := newFakeClient() - - p, err := NewPublisher(client, path, newFactory(""), logger) - if err != nil { - t.Fatalf("failed to create new publisher: %v", err) - } - defer p.Stop() - - if _, err := p.Endpoints(); err != nil { - t.Fatal(err) - } -} - -func TestBadFactory(t *testing.T) { - client := newFakeClient() - - p, err := NewPublisher(client, path, newFactory("kaboom"), logger) - if err != nil { - t.Fatalf("failed to create new publisher: %v", err) - } - defer p.Stop() - - // instance1 came online - client.AddService(path+"/instance1", "kaboom") - - // instance2 came online - client.AddService(path+"/instance2", "zookeeper_node_data") - - if err = asyncTest(100*time.Millisecond, 1, p); err != nil { - t.Error(err) - } -} - -func TestServiceUpdate(t *testing.T) { - client := newFakeClient() - - p, err := NewPublisher(client, path, newFactory(""), logger) - if err != nil { - t.Fatalf("failed to create new publisher: %v", err) - } - defer p.Stop() - - endpoints, err := p.Endpoints() - if err != nil { - t.Fatal(err) - } - - if want, have := 0, len(endpoints); want != have { - t.Errorf("want %d, have %d", want, have) - } - - // instance1 came online - client.AddService(path+"/instance1", "zookeeper_node_data") - - // instance2 came online - client.AddService(path+"/instance2", "zookeeper_node_data2") - - // we should have 2 instances - if err = asyncTest(100*time.Millisecond, 2, p); err != nil { - t.Error(err) - } - - // watch triggers an error... - client.SendErrorOnWatch() - - // test if error was consumed - if err = client.ErrorIsConsumed(100 * time.Millisecond); err != nil { - t.Error(err) - } - - // instance3 came online - client.AddService(path+"/instance3", "zookeeper_node_data3") - - // we should have 3 instances - if err = asyncTest(100*time.Millisecond, 3, p); err != nil { - t.Error(err) - } - - // instance1 goes offline - client.RemoveService(path + "/instance1") - - // instance2 goes offline - client.RemoveService(path + "/instance2") - - // we should have 1 instance - if err = asyncTest(100*time.Millisecond, 1, p); err != nil { - t.Error(err) - } -} - -func TestBadPublisherCreate(t *testing.T) { - client := newFakeClient() - client.SendErrorOnWatch() - p, err := NewPublisher(client, path, newFactory(""), logger) - if err == nil { - t.Error("expected error on new publisher") - } - if p != nil { - t.Error("expected publisher not to be created") - } - p, err = NewPublisher(client, "BadPath", newFactory(""), logger) - if err == nil { - t.Error("expected error on new publisher") - } - if p != nil { - t.Error("expected publisher not to be created") - } -} diff --git a/sd/cache/benchmark_test.go b/sd/cache/benchmark_test.go new file mode 100644 index 000000000..41f1821f9 --- /dev/null +++ b/sd/cache/benchmark_test.go @@ -0,0 +1,29 @@ +package cache + +import ( + "io" + "testing" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +func BenchmarkEndpoints(b *testing.B) { + var ( + ca = make(closer) + cb = make(closer) + cmap = map[string]io.Closer{"a": ca, "b": cb} + factory = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, cmap[instance], nil } + c = New(factory, log.NewNopLogger()) + ) + + b.ReportAllocs() + + c.Update([]string{"a", "b"}) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c.Endpoints() + } + }) +} diff --git a/sd/cache/cache.go b/sd/cache/cache.go new file mode 100644 index 000000000..82af86b51 --- /dev/null +++ b/sd/cache/cache.go @@ -0,0 +1,96 @@ +package cache + +import ( + "io" + "sort" + "sync" + "sync/atomic" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" +) + +// Cache collects the most recent set of endpoints from a service discovery +// system via a subscriber, and makes them available to consumers. Cache is +// meant to be embedded inside of a concrete subscriber, and can serve Service +// invocations directly. +type Cache struct { + mtx sync.RWMutex + factory sd.Factory + cache map[string]endpointCloser + slice atomic.Value // []endpoint.Endpoint + logger log.Logger +} + +type endpointCloser struct { + endpoint.Endpoint + io.Closer +} + +// New returns a new, empty endpoint cache. +func New(factory sd.Factory, logger log.Logger) *Cache { + return &Cache{ + factory: factory, + cache: map[string]endpointCloser{}, + logger: logger, + } +} + +// Update should be invoked by clients with a complete set of current instance +// strings whenever that set changes. The cache manufactures new endpoints via +// the factory, closes old endpoints when they disappear, and persists existing +// endpoints if they survive through an update. +func (c *Cache) Update(instances []string) { + c.mtx.Lock() + defer c.mtx.Unlock() + + // Deterministic order (for later). + sort.Strings(instances) + + // Produce the current set of services. + cache := make(map[string]endpointCloser, len(instances)) + for _, instance := range instances { + // If it already exists, just copy it over. + if sc, ok := c.cache[instance]; ok { + cache[instance] = sc + delete(c.cache, instance) + continue + } + + // If it doesn't exist, create it. + service, closer, err := c.factory(instance) + if err != nil { + c.logger.Log("instance", instance, "err", err) + continue + } + cache[instance] = endpointCloser{service, closer} + } + + // Close any leftover endpoints. + for _, sc := range c.cache { + if sc.Closer != nil { + sc.Closer.Close() + } + } + + // Populate the slice of endpoints. + slice := make([]endpoint.Endpoint, 0, len(cache)) + for _, instance := range instances { + // A bad factory may mean an instance is not present. + if _, ok := cache[instance]; !ok { + continue + } + slice = append(slice, cache[instance].Endpoint) + } + + // Swap and trigger GC for old copies. + c.slice.Store(slice) + c.cache = cache +} + +// Endpoints yields the current set of (presumably identical) endpoints, ordered +// lexicographically by the corresponding instance string. +func (c *Cache) Endpoints() []endpoint.Endpoint { + return c.slice.Load().([]endpoint.Endpoint) +} diff --git a/sd/cache/cache_test.go b/sd/cache/cache_test.go new file mode 100644 index 000000000..be9abafce --- /dev/null +++ b/sd/cache/cache_test.go @@ -0,0 +1,91 @@ +package cache + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +func TestCache(t *testing.T) { + var ( + ca = make(closer) + cb = make(closer) + c = map[string]io.Closer{"a": ca, "b": cb} + f = func(instance string) (endpoint.Endpoint, io.Closer, error) { return endpoint.Nop, c[instance], nil } + cache = New(f, log.NewNopLogger()) + ) + + // Populate + cache.Update([]string{"a", "b"}) + select { + case <-ca: + t.Errorf("endpoint a closed, not good") + case <-cb: + t.Errorf("endpoint b closed, not good") + case <-time.After(time.Millisecond): + t.Logf("no closures yet, good") + } + if want, have := 2, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Duplicate, should be no-op + cache.Update([]string{"a", "b"}) + select { + case <-ca: + t.Errorf("endpoint a closed, not good") + case <-cb: + t.Errorf("endpoint b closed, not good") + case <-time.After(time.Millisecond): + t.Logf("no closures yet, good") + } + if want, have := 2, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Delete b + go cache.Update([]string{"a"}) + select { + case <-ca: + t.Errorf("endpoint a closed, not good") + case <-cb: + t.Logf("endpoint b closed, good") + case <-time.After(time.Second): + t.Errorf("didn't close the deleted instance in time") + } + if want, have := 1, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Delete a + go cache.Update([]string{}) + select { + // case <-cb: will succeed, as it's closed + case <-ca: + t.Logf("endpoint a closed, good") + case <-time.After(time.Second): + t.Errorf("didn't close the deleted instance in time") + } + if want, have := 0, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestBadFactory(t *testing.T) { + cache := New(func(string) (endpoint.Endpoint, io.Closer, error) { + return nil, nil, errors.New("bad factory") + }, log.NewNopLogger()) + + cache.Update([]string{"foo:1234", "bar:5678"}) + if want, have := 0, len(cache.Endpoints()); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type closer chan struct{} + +func (c closer) Close() error { close(c); return nil } diff --git a/sd/consul/client.go b/sd/consul/client.go new file mode 100644 index 000000000..4d88ce3df --- /dev/null +++ b/sd/consul/client.go @@ -0,0 +1,37 @@ +package consul + +import consul "github.com/hashicorp/consul/api" + +// Client is a wrapper around the Consul API. +type Client interface { + // Register a service with the local agent. + Register(r *consul.AgentServiceRegistration) error + + // Deregister a service with the local agent. + Deregister(r *consul.AgentServiceRegistration) error + + // Service + Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) +} + +type client struct { + consul *consul.Client +} + +// NewClient returns an implementation of the Client interface, wrapping a +// concrete Consul client. +func NewClient(c *consul.Client) Client { + return &client{consul: c} +} + +func (c *client) Register(r *consul.AgentServiceRegistration) error { + return c.consul.Agent().ServiceRegister(r) +} + +func (c *client) Deregister(r *consul.AgentServiceRegistration) error { + return c.consul.Agent().ServiceDeregister(r.ID) +} + +func (c *client) Service(service, tag string, passingOnly bool, queryOpts *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { + return c.consul.Health().Service(service, tag, passingOnly, queryOpts) +} diff --git a/sd/consul/client_test.go b/sd/consul/client_test.go new file mode 100644 index 000000000..cf02aea1d --- /dev/null +++ b/sd/consul/client_test.go @@ -0,0 +1,156 @@ +package consul + +import ( + "errors" + "io" + "reflect" + "testing" + + stdconsul "github.com/hashicorp/consul/api" + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" +) + +func TestClientRegistration(t *testing.T) { + c := newTestClient(nil) + + services, _, err := c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{}) + if err != nil { + t.Error(err) + } + if want, have := 0, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } + + if err := c.Register(testRegistration); err != nil { + t.Error(err) + } + + if err := c.Register(testRegistration); err == nil { + t.Errorf("want error, have %v", err) + } + + services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{}) + if err != nil { + t.Error(err) + } + if want, have := 1, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } + + if err := c.Deregister(testRegistration); err != nil { + t.Error(err) + } + + if err := c.Deregister(testRegistration); err == nil { + t.Errorf("want error, have %v", err) + } + + services, _, err = c.Service(testRegistration.Name, "", true, &stdconsul.QueryOptions{}) + if err != nil { + t.Error(err) + } + if want, have := 0, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type testClient struct { + entries []*stdconsul.ServiceEntry +} + +func newTestClient(entries []*stdconsul.ServiceEntry) *testClient { + return &testClient{ + entries: entries, + } +} + +var _ Client = &testClient{} + +func (c *testClient) Service(service, tag string, _ bool, opts *stdconsul.QueryOptions) ([]*stdconsul.ServiceEntry, *stdconsul.QueryMeta, error) { + var results []*stdconsul.ServiceEntry + + for _, entry := range c.entries { + if entry.Service.Service != service { + continue + } + if tag != "" { + tagMap := map[string]struct{}{} + + for _, t := range entry.Service.Tags { + tagMap[t] = struct{}{} + } + + if _, ok := tagMap[tag]; !ok { + continue + } + } + + results = append(results, entry) + } + + return results, &stdconsul.QueryMeta{}, nil +} + +func (c *testClient) Register(r *stdconsul.AgentServiceRegistration) error { + toAdd := registration2entry(r) + + for _, entry := range c.entries { + if reflect.DeepEqual(*entry, *toAdd) { + return errors.New("duplicate") + } + } + + c.entries = append(c.entries, toAdd) + return nil +} + +func (c *testClient) Deregister(r *stdconsul.AgentServiceRegistration) error { + toDelete := registration2entry(r) + + var newEntries []*stdconsul.ServiceEntry + for _, entry := range c.entries { + if reflect.DeepEqual(*entry, *toDelete) { + continue + } + newEntries = append(newEntries, entry) + } + if len(newEntries) == len(c.entries) { + return errors.New("not found") + } + + c.entries = newEntries + return nil +} + +func registration2entry(r *stdconsul.AgentServiceRegistration) *stdconsul.ServiceEntry { + return &stdconsul.ServiceEntry{ + Node: &stdconsul.Node{ + Node: "some-node", + Address: r.Address, + }, + Service: &stdconsul.AgentService{ + ID: r.ID, + Service: r.Name, + Tags: r.Tags, + Port: r.Port, + Address: r.Address, + }, + // Checks ignored + } +} + +func testFactory(instance string) (endpoint.Endpoint, io.Closer, error) { + return func(context.Context, interface{}) (interface{}, error) { + return instance, nil + }, nil, nil +} + +var testRegistration = &stdconsul.AgentServiceRegistration{ + ID: "my-id", + Name: "my-name", + Tags: []string{"my-tag-1", "my-tag-2"}, + Port: 12345, + Address: "my-address", +} diff --git a/sd/consul/integration_test.go b/sd/consul/integration_test.go new file mode 100644 index 000000000..495adad61 --- /dev/null +++ b/sd/consul/integration_test.go @@ -0,0 +1,86 @@ +// +build integration + +package consul + +import ( + "io" + "os" + "testing" + "time" + + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/service" + stdconsul "github.com/hashicorp/consul/api" +) + +func TestIntegration(t *testing.T) { + // Connect to Consul. + // docker run -p 8500:8500 progrium/consul -server -bootstrap + consulAddr := os.Getenv("CONSUL_ADDRESS") + if consulAddr == "" { + t.Fatal("CONSUL_ADDRESS is not set") + } + stdClient, err := stdconsul.NewClient(&stdconsul.Config{ + Address: consulAddr, + }) + if err != nil { + t.Fatal(err) + } + client := NewClient(stdClient) + logger := log.NewLogfmtLogger(os.Stderr) + + // Produce a fake service registration. + r := &stdconsul.AgentServiceRegistration{ + ID: "my-service-ID", + Name: "my-service-name", + Tags: []string{"alpha", "beta"}, + Port: 12345, + Address: "my-address", + EnableTagOverride: false, + // skipping check(s) + } + + // Build a subscriber on r.Name + r.Tags. + factory := func(instance string) (service.Service, io.Closer, error) { + t.Logf("factory invoked for %q", instance) + return service.Fixed{}, nil, nil + } + subscriber, err := NewSubscriber( + client, + factory, + log.NewContext(logger).With("component", "subscriber"), + r.Name, + r.Tags, + true, + ) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second) + + // Before we publish, we should have no services. + services, err := subscriber.Services() + if err != nil { + t.Error(err) + } + if want, have := 0, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Build a registrar for r. + registrar := NewRegistrar(client, r, log.NewContext(logger).With("component", "registrar")) + registrar.Register() + defer registrar.Deregister() + + time.Sleep(time.Second) + + // Now we should have one active service. + services, err = subscriber.Services() + if err != nil { + t.Error(err) + } + if want, have := 1, len(services); want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/sd/consul/registrar.go b/sd/consul/registrar.go new file mode 100644 index 000000000..e89fef696 --- /dev/null +++ b/sd/consul/registrar.go @@ -0,0 +1,44 @@ +package consul + +import ( + "fmt" + + stdconsul "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/log" +) + +// Registrar registers service instance liveness information to Consul. +type Registrar struct { + client Client + registration *stdconsul.AgentServiceRegistration + logger log.Logger +} + +// NewRegistrar returns a Consul Registrar acting on the provided catalog +// registration. +func NewRegistrar(client Client, r *stdconsul.AgentServiceRegistration, logger log.Logger) *Registrar { + return &Registrar{ + client: client, + registration: r, + logger: log.NewContext(logger).With("service", r.Name, "tags", fmt.Sprint(r.Tags), "address", r.Address), + } +} + +// Register implements sd.Registrar interface. +func (p *Registrar) Register() { + if err := p.client.Register(p.registration); err != nil { + p.logger.Log("err", err) + } else { + p.logger.Log("action", "register") + } +} + +// Deregister implements sd.Registrar interface. +func (p *Registrar) Deregister() { + if err := p.client.Deregister(p.registration); err != nil { + p.logger.Log("err", err) + } else { + p.logger.Log("action", "deregister") + } +} diff --git a/sd/consul/registrar_test.go b/sd/consul/registrar_test.go new file mode 100644 index 000000000..edc772327 --- /dev/null +++ b/sd/consul/registrar_test.go @@ -0,0 +1,27 @@ +package consul + +import ( + "testing" + + stdconsul "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/log" +) + +func TestRegistrar(t *testing.T) { + client := newTestClient([]*stdconsul.ServiceEntry{}) + p := NewRegistrar(client, testRegistration, log.NewNopLogger()) + if want, have := 0, len(client.entries); want != have { + t.Errorf("want %d, have %d", want, have) + } + + p.Register() + if want, have := 1, len(client.entries); want != have { + t.Errorf("want %d, have %d", want, have) + } + + p.Deregister() + if want, have := 0, len(client.entries); want != have { + t.Errorf("want %d, have %d", want, have) + } +} diff --git a/sd/consul/subscriber.go b/sd/consul/subscriber.go new file mode 100644 index 000000000..ee3ae34bb --- /dev/null +++ b/sd/consul/subscriber.go @@ -0,0 +1,166 @@ +package consul + +import ( + "fmt" + "io" + + consul "github.com/hashicorp/consul/api" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +const defaultIndex = 0 + +// Subscriber yields endpoints for a service in Consul. Updates to the service +// are watched and will update the Subscriber endpoints. +type Subscriber struct { + cache *cache.Cache + client Client + logger log.Logger + service string + tags []string + passingOnly bool + endpointsc chan []endpoint.Endpoint + quitc chan struct{} +} + +var _ sd.Subscriber = &Subscriber{} + +// NewSubscriber returns a Consul subscriber which returns endpoints for the +// requested service. It only returns instances for which all of the passed tags +// are present. +func NewSubscriber(client Client, factory sd.Factory, logger log.Logger, service string, tags []string, passingOnly bool) *Subscriber { + s := &Subscriber{ + cache: cache.New(factory, logger), + client: client, + logger: log.NewContext(logger).With("service", service, "tags", fmt.Sprint(tags)), + service: service, + tags: tags, + passingOnly: passingOnly, + quitc: make(chan struct{}), + } + + instances, index, err := s.getInstances(defaultIndex, nil) + if err == nil { + s.logger.Log("instances", len(instances)) + } else { + s.logger.Log("err", err) + } + + s.cache.Update(instances) + go s.loop(index) + return s +} + +// Endpoints implements the Subscriber interface. +func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return s.cache.Endpoints(), nil +} + +// Stop terminates the subscriber. +func (s *Subscriber) Stop() { + close(s.quitc) +} + +func (s *Subscriber) loop(lastIndex uint64) { + var ( + instances []string + err error + ) + for { + instances, lastIndex, err = s.getInstances(lastIndex, s.quitc) + switch { + case err == io.EOF: + return // stopped via quitc + case err != nil: + s.logger.Log("err", err) + default: + s.cache.Update(instances) + } + } +} + +func (s *Subscriber) getInstances(lastIndex uint64, interruptc chan struct{}) ([]string, uint64, error) { + tag := "" + if len(s.tags) > 0 { + tag = s.tags[0] + } + + // Consul doesn't support more than one tag in its service query method. + // https://github.com/hashicorp/consul/issues/294 + // Hashi suggest prepared queries, but they don't support blocking. + // https://www.consul.io/docs/agent/http/query.html#execute + // If we want blocking for efficiency, we must filter tags manually. + + type response struct { + instances []string + index uint64 + } + + var ( + errc = make(chan error, 1) + resc = make(chan response, 1) + ) + + go func() { + entries, meta, err := s.client.Service(s.service, tag, s.passingOnly, &consul.QueryOptions{ + WaitIndex: lastIndex, + }) + if err != nil { + errc <- err + return + } + if len(s.tags) > 1 { + entries = filterEntries(entries, s.tags[1:]...) + } + resc <- response{ + instances: makeInstances(entries), + index: meta.LastIndex, + } + }() + + select { + case err := <-errc: + return nil, 0, err + case res := <-resc: + return res.instances, res.index, nil + case <-interruptc: + return nil, 0, io.EOF + } +} + +func filterEntries(entries []*consul.ServiceEntry, tags ...string) []*consul.ServiceEntry { + var es []*consul.ServiceEntry + +ENTRIES: + for _, entry := range entries { + ts := make(map[string]struct{}, len(entry.Service.Tags)) + for _, tag := range entry.Service.Tags { + ts[tag] = struct{}{} + } + + for _, tag := range tags { + if _, ok := ts[tag]; !ok { + continue ENTRIES + } + } + es = append(es, entry) + } + + return es +} + +func makeInstances(entries []*consul.ServiceEntry) []string { + instances := make([]string, len(entries)) + for i, entry := range entries { + addr := entry.Node.Address + if entry.Service.Address != "" { + addr = entry.Service.Address + } + instances[i] = fmt.Sprintf("%s:%d", addr, entry.Service.Port) + } + return instances +} diff --git a/sd/consul/subscriber_test.go b/sd/consul/subscriber_test.go new file mode 100644 index 000000000..f581216eb --- /dev/null +++ b/sd/consul/subscriber_test.go @@ -0,0 +1,138 @@ +package consul + +import ( + "testing" + + consul "github.com/hashicorp/consul/api" + "golang.org/x/net/context" + + "github.com/go-kit/kit/log" +) + +var consulState = []*consul.ServiceEntry{ + { + Node: &consul.Node{ + Address: "10.0.0.0", + Node: "app00.local", + }, + Service: &consul.AgentService{ + ID: "search-api-0", + Port: 8000, + Service: "search", + Tags: []string{ + "api", + "v1", + }, + }, + }, + { + Node: &consul.Node{ + Address: "10.0.0.1", + Node: "app01.local", + }, + Service: &consul.AgentService{ + ID: "search-api-1", + Port: 8001, + Service: "search", + Tags: []string{ + "api", + "v2", + }, + }, + }, + { + Node: &consul.Node{ + Address: "10.0.0.1", + Node: "app01.local", + }, + Service: &consul.AgentService{ + Address: "10.0.0.10", + ID: "search-db-0", + Port: 9000, + Service: "search", + Tags: []string{ + "db", + }, + }, + }, +} + +func TestSubscriber(t *testing.T) { + var ( + logger = log.NewNopLogger() + client = newTestClient(consulState) + ) + + s := NewSubscriber(client, testFactory, logger, "search", []string{"api"}, true) + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 2, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestSubscriberNoService(t *testing.T) { + var ( + logger = log.NewNopLogger() + client = newTestClient(consulState) + ) + + s := NewSubscriber(client, testFactory, logger, "feed", []string{}, true) + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 0, len(endpoints); want != have { + t.Fatalf("want %d, have %d", want, have) + } +} + +func TestSubscriberWithTags(t *testing.T) { + var ( + logger = log.NewNopLogger() + client = newTestClient(consulState) + ) + + s := NewSubscriber(client, testFactory, logger, "search", []string{"api", "v2"}, true) + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 1, len(endpoints); want != have { + t.Fatalf("want %d, have %d", want, have) + } +} + +func TestSubscriberAddressOverride(t *testing.T) { + s := NewSubscriber(newTestClient(consulState), testFactory, log.NewNopLogger(), "search", []string{"db"}, true) + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + + if want, have := 1, len(endpoints); want != have { + t.Fatalf("want %d, have %d", want, have) + } + + response, err := endpoints[0](context.Background(), struct{}{}) + if err != nil { + t.Fatal(err) + } + + if want, have := "10.0.0.10:9000", response.(string); want != have { + t.Errorf("want %q, have %q", want, have) + } +} diff --git a/sd/dnssrv/lookup.go b/sd/dnssrv/lookup.go new file mode 100644 index 000000000..9d46ea6e4 --- /dev/null +++ b/sd/dnssrv/lookup.go @@ -0,0 +1,7 @@ +package dnssrv + +import "net" + +// Lookup is a function that resolves a DNS SRV record to multiple addresses. +// It has the same signature as net.LookupSRV. +type Lookup func(service, proto, name string) (cname string, addrs []*net.SRV, err error) diff --git a/sd/dnssrv/subscriber.go b/sd/dnssrv/subscriber.go new file mode 100644 index 000000000..422fdaa76 --- /dev/null +++ b/sd/dnssrv/subscriber.go @@ -0,0 +1,100 @@ +package dnssrv + +import ( + "fmt" + "net" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +// Subscriber yields endpoints taken from the named DNS SRV record. The name is +// resolved on a fixed schedule. Priorities and weights are ignored. +type Subscriber struct { + name string + cache *cache.Cache + logger log.Logger + quit chan struct{} +} + +// NewSubscriber returns a DNS SRV subscriber. +func NewSubscriber( + name string, + ttl time.Duration, + factory sd.Factory, + logger log.Logger, +) *Subscriber { + return NewSubscriberDetailed(name, time.NewTicker(ttl), net.LookupSRV, factory, logger) +} + +// NewSubscriberDetailed is the same as NewSubscriber, but allows users to +// provide an explicit lookup refresh ticker instead of a TTL, and specify the +// lookup function instead of using net.LookupSRV. +func NewSubscriberDetailed( + name string, + refresh *time.Ticker, + lookup Lookup, + factory sd.Factory, + logger log.Logger, +) *Subscriber { + p := &Subscriber{ + name: name, + cache: cache.New(factory, logger), + logger: logger, + quit: make(chan struct{}), + } + + instances, err := p.resolve(lookup) + if err == nil { + logger.Log("name", name, "instances", len(instances)) + } else { + logger.Log("name", name, "err", err) + } + p.cache.Update(instances) + + go p.loop(refresh, lookup) + return p +} + +// Stop terminates the Subscriber. +func (p *Subscriber) Stop() { + close(p.quit) +} + +func (p *Subscriber) loop(t *time.Ticker, lookup Lookup) { + defer t.Stop() + for { + select { + case <-t.C: + instances, err := p.resolve(lookup) + if err != nil { + p.logger.Log("name", p.name, "err", err) + continue // don't replace potentially-good with bad + } + p.cache.Update(instances) + + case <-p.quit: + return + } + } +} + +// Endpoints implements the Subscriber interface. +func (p *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return p.cache.Endpoints(), nil +} + +func (p *Subscriber) resolve(lookup Lookup) ([]string, error) { + _, addrs, err := lookup("", "", p.name) + if err != nil { + return []string{}, err + } + instances := make([]string, len(addrs)) + for i, addr := range addrs { + instances[i] = net.JoinHostPort(addr.Target, fmt.Sprint(addr.Port)) + } + return instances, nil +} diff --git a/sd/dnssrv/subscriber_test.go b/sd/dnssrv/subscriber_test.go new file mode 100644 index 000000000..5a9036ad6 --- /dev/null +++ b/sd/dnssrv/subscriber_test.go @@ -0,0 +1,85 @@ +package dnssrv + +import ( + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" +) + +func TestRefresh(t *testing.T) { + name := "some.service.internal" + + ticker := time.NewTicker(time.Second) + ticker.Stop() + tickc := make(chan time.Time) + ticker.C = tickc + + var lookups uint64 + records := []*net.SRV{} + lookup := func(service, proto, name string) (string, []*net.SRV, error) { + t.Logf("lookup(%q, %q, %q)", service, proto, name) + atomic.AddUint64(&lookups, 1) + return "cname", records, nil + } + + var generates uint64 + factory := func(instance string) (endpoint.Endpoint, io.Closer, error) { + t.Logf("factory(%q)", instance) + atomic.AddUint64(&generates, 1) + return endpoint.Nop, nopCloser{}, nil + } + + subscriber := NewSubscriberDetailed(name, ticker, lookup, factory, log.NewNopLogger()) + defer subscriber.Stop() + + // First lookup, empty + endpoints, err := subscriber.Endpoints() + if err != nil { + t.Error(err) + } + if want, have := 0, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(1), atomic.LoadUint64(&lookups); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(0), atomic.LoadUint64(&generates); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // Load some records and lookup again + records = []*net.SRV{ + &net.SRV{Target: "1.0.0.1", Port: 1001}, + &net.SRV{Target: "1.0.0.2", Port: 1002}, + &net.SRV{Target: "1.0.0.3", Port: 1003}, + } + tickc <- time.Now() + + // There is a race condition where the subscriber.Endpoints call below + // invokes the cache before it is updated by the tick above. + // TODO(pb): solve by running the read through the loop goroutine. + time.Sleep(100 * time.Millisecond) + + endpoints, err = subscriber.Endpoints() + if err != nil { + t.Error(err) + } + if want, have := 3, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(2), atomic.LoadUint64(&lookups); want != have { + t.Errorf("want %d, have %d", want, have) + } + if want, have := uint64(len(records)), atomic.LoadUint64(&generates); want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +type nopCloser struct{} + +func (nopCloser) Close() error { return nil } diff --git a/sd/doc.go b/sd/doc.go new file mode 100644 index 000000000..b10d96fd7 --- /dev/null +++ b/sd/doc.go @@ -0,0 +1,5 @@ +// Package sd provides utilities related to service discovery. That includes +// subscribing to service discovery systems in order to reach remote instances, +// and publishing to service discovery systems to make an instance available. +// Implementations are provided for most common systems. +package sd diff --git a/loadbalancer/etcd/client.go b/sd/etcd/client.go similarity index 95% rename from loadbalancer/etcd/client.go rename to sd/etcd/client.go index 9abfd3415..b9e2904a0 100644 --- a/loadbalancer/etcd/client.go +++ b/sd/etcd/client.go @@ -16,6 +16,7 @@ import ( type Client interface { // GetEntries will query the given prefix in etcd and returns a set of entries. GetEntries(prefix string) ([]string, error) + // WatchPrefix starts watching every change for given prefix in etcd. When an // change is detected it will populate the responseChan when an *etcd.Response. WatchPrefix(prefix string, responseChan chan *etcd.Response) @@ -26,6 +27,7 @@ type client struct { ctx context.Context } +// ClientOptions defines options for the etcd client. type ClientOptions struct { Cert string Key string @@ -39,16 +41,13 @@ type ClientOptions struct { // It will return an error if a connection to the cluster cannot be made. // The parameter machines needs to be a full URL with schemas. // e.g. "http://localhost:2379" will work, but "localhost:2379" will not. -func NewClient(ctx context.Context, machines []string, options *ClientOptions) (Client, error) { +func NewClient(ctx context.Context, machines []string, options ClientOptions) (Client, error) { var ( c etcd.KeysAPI err error caCertCt []byte tlsCert tls.Certificate ) - if options == nil { - options = &ClientOptions{} - } if options.Cert != "" && options.Key != "" { tlsCert, err = tls.LoadX509KeyPair(options.Cert, options.Key) @@ -101,6 +100,7 @@ func NewClient(ctx context.Context, machines []string, options *ClientOptions) ( } c = etcd.NewKeysAPI(ce) } + return &client{c, ctx}, nil } diff --git a/sd/etcd/subscriber.go b/sd/etcd/subscriber.go new file mode 100644 index 000000000..1d579eb21 --- /dev/null +++ b/sd/etcd/subscriber.go @@ -0,0 +1,74 @@ +package etcd + +import ( + etcd "github.com/coreos/etcd/client" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +// Subscriber yield endpoints stored in a certain etcd keyspace. Any kind of +// change in that keyspace is watched and will update the Subscriber endpoints. +type Subscriber struct { + client Client + prefix string + cache *cache.Cache + logger log.Logger + quitc chan struct{} +} + +var _ sd.Subscriber = &Subscriber{} + +// NewSubscriber returns an etcd subscriber. It will start watching the given +// prefix for changes, and update the endpoints. +func NewSubscriber(c Client, prefix string, factory sd.Factory, logger log.Logger) (*Subscriber, error) { + s := &Subscriber{ + client: c, + prefix: prefix, + cache: cache.New(factory, logger), + logger: logger, + quitc: make(chan struct{}), + } + + instances, err := s.client.GetEntries(s.prefix) + if err == nil { + logger.Log("prefix", s.prefix, "instances", len(instances)) + } else { + logger.Log("prefix", s.prefix, "err", err) + } + s.cache.Update(instances) + + go s.loop() + return s, nil +} + +func (s *Subscriber) loop() { + responseChan := make(chan *etcd.Response) + go s.client.WatchPrefix(s.prefix, responseChan) + for { + select { + case <-responseChan: + instances, err := s.client.GetEntries(s.prefix) + if err != nil { + s.logger.Log("msg", "failed to retrieve entries", "err", err) + continue + } + s.cache.Update(instances) + + case <-s.quitc: + return + } + } +} + +// Endpoints implements the Subscriber interface. +func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return s.cache.Endpoints(), nil +} + +// Stop terminates the Subscriber. +func (s *Subscriber) Stop() { + close(s.quitc) +} diff --git a/loadbalancer/etcd/publisher_test.go b/sd/etcd/subscriber_test.go similarity index 65% rename from loadbalancer/etcd/publisher_test.go rename to sd/etcd/subscriber_test.go index 8d1f51770..0073e1e97 100644 --- a/loadbalancer/etcd/publisher_test.go +++ b/sd/etcd/subscriber_test.go @@ -1,4 +1,4 @@ -package etcd_test +package etcd import ( "errors" @@ -6,10 +6,8 @@ import ( "testing" stdetcd "github.com/coreos/etcd/client" - "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" - kitetcd "github.com/go-kit/kit/loadbalancer/etcd" "github.com/go-kit/kit/log" ) @@ -26,34 +24,27 @@ var ( } ) -func TestPublisher(t *testing.T) { - var ( - logger = log.NewNopLogger() - e = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - ) - +func TestSubscriber(t *testing.T) { factory := func(string) (endpoint.Endpoint, io.Closer, error) { - return e, nil, nil + return endpoint.Nop, nil, nil } client := &fakeClient{ responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, } - p, err := kitetcd.NewPublisher(client, "/foo", factory, logger) + s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger()) if err != nil { - t.Fatalf("failed to create new publisher: %v", err) + t.Fatal(err) } - defer p.Stop() + defer s.Stop() - if _, err := p.Endpoints(); err != nil { + if _, err := s.Endpoints(); err != nil { t.Fatal(err) } } func TestBadFactory(t *testing.T) { - logger := log.NewNopLogger() - factory := func(string) (endpoint.Endpoint, io.Closer, error) { return nil, nil, errors.New("kaboom") } @@ -62,19 +53,19 @@ func TestBadFactory(t *testing.T) { responses: map[string]*stdetcd.Response{"/foo": fakeResponse}, } - p, err := kitetcd.NewPublisher(client, "/foo", factory, logger) + s, err := NewSubscriber(client, "/foo", factory, log.NewNopLogger()) if err != nil { - t.Fatalf("failed to create new publisher: %v", err) + t.Fatal(err) } - defer p.Stop() + defer s.Stop() - endpoints, err := p.Endpoints() + endpoints, err := s.Endpoints() if err != nil { t.Fatal(err) } if want, have := 0, len(endpoints); want != have { - t.Errorf("want %q, have %q", want, have) + t.Errorf("want %d, have %d", want, have) } } diff --git a/sd/factory.go b/sd/factory.go new file mode 100644 index 000000000..af99817b4 --- /dev/null +++ b/sd/factory.go @@ -0,0 +1,17 @@ +package sd + +import ( + "io" + + "github.com/go-kit/kit/endpoint" +) + +// Factory is a function that converts an instance string (e.g. host:port) to a +// specific endpoint. Instances that provide multiple endpoints require multiple +// factories. A factory also returns an io.Closer that's invoked when the +// instance goes away and needs to be cleaned up. Factories may return nil +// closers. +// +// Users are expected to provide their own factory functions that assume +// specific transports, or can deduce transports by parsing the instance string. +type Factory func(instance string) (endpoint.Endpoint, io.Closer, error) diff --git a/sd/fixed_subscriber.go b/sd/fixed_subscriber.go new file mode 100644 index 000000000..98fd50323 --- /dev/null +++ b/sd/fixed_subscriber.go @@ -0,0 +1,9 @@ +package sd + +import "github.com/go-kit/kit/endpoint" + +// FixedSubscriber yields a fixed set of services. +type FixedSubscriber []endpoint.Endpoint + +// Endpoints implements Subscriber. +func (s FixedSubscriber) Endpoints() ([]endpoint.Endpoint, error) { return s, nil } diff --git a/sd/lb/balancer.go b/sd/lb/balancer.go new file mode 100644 index 000000000..40aa0ef27 --- /dev/null +++ b/sd/lb/balancer.go @@ -0,0 +1,15 @@ +package lb + +import ( + "errors" + + "github.com/go-kit/kit/endpoint" +) + +// Balancer yields endpoints according to some heuristic. +type Balancer interface { + Endpoint() (endpoint.Endpoint, error) +} + +// ErrNoEndpoints is returned when no qualifying endpoints are available. +var ErrNoEndpoints = errors.New("no endpoints available") diff --git a/sd/lb/doc.go b/sd/lb/doc.go new file mode 100644 index 000000000..82a9516d7 --- /dev/null +++ b/sd/lb/doc.go @@ -0,0 +1,5 @@ +// Package lb deals with client-side load balancing across multiple identical +// instances of services and endpoints. When combined with a service discovery +// system of record, it enables a more decentralized architecture, removing the +// need for separate load balancers like HAProxy. +package lb diff --git a/sd/lb/random.go b/sd/lb/random.go new file mode 100644 index 000000000..78b095673 --- /dev/null +++ b/sd/lb/random.go @@ -0,0 +1,32 @@ +package lb + +import ( + "math/rand" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" +) + +// NewRandom returns a load balancer that selects services randomly. +func NewRandom(s sd.Subscriber, seed int64) Balancer { + return &random{ + s: s, + r: rand.New(rand.NewSource(seed)), + } +} + +type random struct { + s sd.Subscriber + r *rand.Rand +} + +func (r *random) Endpoint() (endpoint.Endpoint, error) { + endpoints, err := r.s.Endpoints() + if err != nil { + return nil, err + } + if len(endpoints) <= 0 { + return nil, ErrNoEndpoints + } + return endpoints[r.r.Intn(len(endpoints))], nil +} diff --git a/sd/lb/random_test.go b/sd/lb/random_test.go new file mode 100644 index 000000000..c9b011789 --- /dev/null +++ b/sd/lb/random_test.go @@ -0,0 +1,52 @@ +package lb + +import ( + "math" + "testing" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" + "golang.org/x/net/context" +) + +func TestRandom(t *testing.T) { + var ( + n = 7 + endpoints = make([]endpoint.Endpoint, n) + counts = make([]int, n) + seed = int64(12345) + iterations = 1000000 + want = iterations / n + tolerance = want / 100 // 1% + ) + + for i := 0; i < n; i++ { + i0 := i + endpoints[i] = func(context.Context, interface{}) (interface{}, error) { counts[i0]++; return struct{}{}, nil } + } + + subscriber := sd.FixedSubscriber(endpoints) + balancer := NewRandom(subscriber, seed) + + for i := 0; i < iterations; i++ { + endpoint, _ := balancer.Endpoint() + endpoint(context.Background(), struct{}{}) + } + + for i, have := range counts { + delta := int(math.Abs(float64(want - have))) + if delta > tolerance { + t.Errorf("%d: want %d, have %d, delta %d > %d tolerance", i, want, have, delta, tolerance) + } + } +} + +func TestRandomNoEndpoints(t *testing.T) { + subscriber := sd.FixedSubscriber{} + balancer := NewRandom(subscriber, 1415926) + _, err := balancer.Endpoint() + if want, have := ErrNoEndpoints, err; want != have { + t.Errorf("want %v, have %v", want, have) + } + +} diff --git a/loadbalancer/retry.go b/sd/lb/retry.go similarity index 74% rename from loadbalancer/retry.go rename to sd/lb/retry.go index ed931f839..a933eeb0c 100644 --- a/loadbalancer/retry.go +++ b/sd/lb/retry.go @@ -1,4 +1,4 @@ -package loadbalancer +package lb import ( "fmt" @@ -10,16 +10,16 @@ import ( "github.com/go-kit/kit/endpoint" ) -// Retry wraps the load balancer to make it behave like a simple endpoint. +// Retry wraps a service load balancer and returns an endpoint oriented load +// balancer for the specified service method. // Requests to the endpoint will be automatically load balanced via the load // balancer. Requests that return errors will be retried until they succeed, // up to max times, or until the timeout is elapsed, whichever comes first. -func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { - if lb == nil { - panic("nil LoadBalancer") +func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint { + if b == nil { + panic("nil Balancer") } - - return func(ctx context.Context, request interface{}) (interface{}, error) { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { var ( newctx, cancel = context.WithTimeout(ctx, timeout) responses = make(chan interface{}, 1) @@ -29,7 +29,7 @@ func Retry(max int, timeout time.Duration, lb LoadBalancer) endpoint.Endpoint { defer cancel() for i := 1; i <= max; i++ { go func() { - e, err := lb.Endpoint() + e, err := b.Endpoint() if err != nil { errs <- err return diff --git a/loadbalancer/retry_test.go b/sd/lb/retry_test.go similarity index 82% rename from loadbalancer/retry_test.go rename to sd/lb/retry_test.go index 004c69051..07b1afdb7 100644 --- a/loadbalancer/retry_test.go +++ b/sd/lb/retry_test.go @@ -1,4 +1,4 @@ -package loadbalancer_test +package lb_test import ( "errors" @@ -8,15 +8,14 @@ import ( "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" - "github.com/go-kit/kit/loadbalancer/fixed" + "github.com/go-kit/kit/sd" + loadbalancer "github.com/go-kit/kit/sd/lb" ) func TestRetryMaxTotalFail(t *testing.T) { var ( - endpoints = []endpoint.Endpoint{} // no endpoints - p = fixed.NewPublisher(endpoints) - lb = loadbalancer.NewRoundRobin(p) + endpoints = sd.FixedSubscriber{} // no endpoints + lb = loadbalancer.NewRoundRobin(endpoints) retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries ctx = context.Background() ) @@ -32,9 +31,13 @@ func TestRetryMaxPartialFail(t *testing.T) { func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, } + subscriber = sd.FixedSubscriber{ + 0: endpoints[0], + 1: endpoints[1], + 2: endpoints[2], + } retries = len(endpoints) - 1 // not quite enough retries - p = fixed.NewPublisher(endpoints) - lb = loadbalancer.NewRoundRobin(p) + lb = loadbalancer.NewRoundRobin(subscriber) ctx = context.Background() ) if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil { @@ -49,9 +52,13 @@ func TestRetryMaxSuccess(t *testing.T) { func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") }, func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ }, } + subscriber = sd.FixedSubscriber{ + 0: endpoints[0], + 1: endpoints[1], + 2: endpoints[2], + } retries = len(endpoints) // exactly enough retries - p = fixed.NewPublisher(endpoints) - lb = loadbalancer.NewRoundRobin(p) + lb = loadbalancer.NewRoundRobin(subscriber) ctx = context.Background() ) if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil { @@ -64,7 +71,7 @@ func TestRetryTimeout(t *testing.T) { step = make(chan struct{}) e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil } timeout = time.Millisecond - retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(fixed.NewPublisher([]endpoint.Endpoint{e}))) + retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e})) errs = make(chan error, 1) invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err } ) diff --git a/sd/lb/round_robin.go b/sd/lb/round_robin.go new file mode 100644 index 000000000..74b86caea --- /dev/null +++ b/sd/lb/round_robin.go @@ -0,0 +1,34 @@ +package lb + +import ( + "sync/atomic" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" +) + +// NewRoundRobin returns a load balancer that returns services in sequence. +func NewRoundRobin(s sd.Subscriber) Balancer { + return &roundRobin{ + s: s, + c: 0, + } +} + +type roundRobin struct { + s sd.Subscriber + c uint64 +} + +func (rr *roundRobin) Endpoint() (endpoint.Endpoint, error) { + endpoints, err := rr.s.Endpoints() + if err != nil { + return nil, err + } + if len(endpoints) <= 0 { + return nil, ErrNoEndpoints + } + old := atomic.AddUint64(&rr.c, 1) - 1 + idx := old % uint64(len(endpoints)) + return endpoints[idx], nil +} diff --git a/sd/lb/round_robin_test.go b/sd/lb/round_robin_test.go new file mode 100644 index 000000000..64a8baa45 --- /dev/null +++ b/sd/lb/round_robin_test.go @@ -0,0 +1,96 @@ +package lb + +import ( + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/sd" +) + +func TestRoundRobin(t *testing.T) { + var ( + counts = []int{0, 0, 0} + endpoints = []endpoint.Endpoint{ + func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil }, + func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil }, + } + ) + + subscriber := sd.FixedSubscriber(endpoints) + balancer := NewRoundRobin(subscriber) + + for i, want := range [][]int{ + {1, 0, 0}, + {1, 1, 0}, + {1, 1, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 2, 2}, + {3, 2, 2}, + } { + endpoint, err := balancer.Endpoint() + if err != nil { + t.Fatal(err) + } + endpoint(context.Background(), struct{}{}) + if have := counts; !reflect.DeepEqual(want, have) { + t.Fatalf("%d: want %v, have %v", i, want, have) + } + } +} + +func TestRoundRobinNoEndpoints(t *testing.T) { + subscriber := sd.FixedSubscriber{} + balancer := NewRoundRobin(subscriber) + _, err := balancer.Endpoint() + if want, have := ErrNoEndpoints, err; want != have { + t.Errorf("want %v, have %v", want, have) + } +} + +func TestRoundRobinNoRace(t *testing.T) { + balancer := NewRoundRobin(sd.FixedSubscriber([]endpoint.Endpoint{ + endpoint.Nop, + endpoint.Nop, + endpoint.Nop, + endpoint.Nop, + endpoint.Nop, + })) + + var ( + n = 100 + done = make(chan struct{}) + wg sync.WaitGroup + count uint64 + ) + + wg.Add(n) + + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + _, _ = balancer.Endpoint() + atomic.AddUint64(&count, 1) + } + } + }() + } + + time.Sleep(time.Second) + close(done) + wg.Wait() + + t.Logf("made %d calls", atomic.LoadUint64(&count)) +} diff --git a/sd/registrar.go b/sd/registrar.go new file mode 100644 index 000000000..49a0c9f21 --- /dev/null +++ b/sd/registrar.go @@ -0,0 +1,13 @@ +package sd + +// Registrar registers instance information to a service discovery system when +// an instance becomes alive and healthy, and deregisters that information when +// the service becomes unhealthy or goes away. +// +// Registrar implementations exist for various service discovery systems. Note +// that identifying instance information (e.g. host:port) must be given via the +// concrete constructor; this interface merely signals lifecycle changes. +type Registrar interface { + Register() + Deregister() +} diff --git a/sd/subscriber.go b/sd/subscriber.go new file mode 100644 index 000000000..8267b51bb --- /dev/null +++ b/sd/subscriber.go @@ -0,0 +1,11 @@ +package sd + +import "github.com/go-kit/kit/endpoint" + +// Subscriber listens to a service discovery system and yields a set of +// identical endpoints on demand. An error indicates a problem with connectivity +// to the service discovery system, or within the system itself; a subscriber +// may yield no endpoints without error. +type Subscriber interface { + Endpoints() ([]endpoint.Endpoint, error) +} diff --git a/loadbalancer/zk/client.go b/sd/zk/client.go similarity index 100% rename from loadbalancer/zk/client.go rename to sd/zk/client.go diff --git a/loadbalancer/zk/client_test.go b/sd/zk/client_test.go similarity index 88% rename from loadbalancer/zk/client_test.go rename to sd/zk/client_test.go index 2a7c52062..fbb2a5a17 100644 --- a/loadbalancer/zk/client_test.go +++ b/sd/zk/client_test.go @@ -107,15 +107,15 @@ func TestCreateParentNodes(t *testing.T) { t.Fatal("expected new Client, got nil") } - p, err := NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger()) + s, err := NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger()) if err != stdzk.ErrNoServer { t.Errorf("unexpected error: %v", err) } - if p != nil { - t.Error("expected failed new Publisher") + if s != nil { + t.Error("expected failed new Subscriber") } - p, err = NewPublisher(c, "invalidpath", newFactory(""), log.NewNopLogger()) + s, err = NewSubscriber(c, "invalidpath", newFactory(""), log.NewNopLogger()) if err != stdzk.ErrInvalidPath { t.Errorf("unexpected error: %v", err) } @@ -131,12 +131,12 @@ func TestCreateParentNodes(t *testing.T) { t.Errorf("unexpected error: %v", err) } - p, err = NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger()) + s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger()) if err != ErrClientClosed { t.Errorf("unexpected error: %v", err) } - if p != nil { - t.Error("expected failed new Publisher") + if s != nil { + t.Error("expected failed new Subscriber") } c, err = NewClient([]string{"localhost:65500"}, log.NewNopLogger(), Payload(payload)) @@ -147,11 +147,11 @@ func TestCreateParentNodes(t *testing.T) { t.Fatal("expected new Client, got nil") } - p, err = NewPublisher(c, "/validpath", newFactory(""), log.NewNopLogger()) + s, err = NewSubscriber(c, "/validpath", newFactory(""), log.NewNopLogger()) if err != stdzk.ErrNoServer { t.Errorf("unexpected error: %v", err) } - if p != nil { - t.Error("expected failed new Publisher") + if s != nil { + t.Error("expected failed new Subscriber") } } diff --git a/loadbalancer/zk/integration_test.go b/sd/zk/integration_test.go similarity index 85% rename from loadbalancer/zk/integration_test.go rename to sd/zk/integration_test.go index 96b1e2e67..0e67679b5 100644 --- a/loadbalancer/zk/integration_test.go +++ b/sd/zk/integration_test.go @@ -46,17 +46,17 @@ func TestCreateParentNodesOnServer(t *testing.T) { } defer c1.Stop() - p, err := NewPublisher(c1, path, newFactory(""), logger) + s, err := NewSubscriber(c1, path, newFactory(""), logger) if err != nil { - t.Fatalf("Unable to create Publisher: %v", err) + t.Fatalf("Unable to create Subscriber: %v", err) } - defer p.Stop() + defer s.Stop() - endpoints, err := p.Endpoints() + services, err := s.Services() if err != nil { t.Fatal(err) } - if want, have := 0, len(endpoints); want != have { + if want, have := 0, len(services); want != have { t.Errorf("want %d, have %d", want, have) } @@ -81,7 +81,7 @@ func TestCreateBadParentNodesOnServer(t *testing.T) { c, _ := NewClient(host, logger) defer c.Stop() - _, err := NewPublisher(c, "invalid/path", newFactory(""), logger) + _, err := NewSubscriber(c, "invalid/path", newFactory(""), logger) if want, have := stdzk.ErrInvalidPath, err; want != have { t.Errorf("want %v, have %v", want, have) @@ -93,7 +93,7 @@ func TestCredentials1(t *testing.T) { c, _ := NewClient(host, logger, ACL(acl), Credentials("user", "secret")) defer c.Stop() - _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger) + _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) if err != nil { t.Fatal(err) @@ -105,7 +105,7 @@ func TestCredentials2(t *testing.T) { c, _ := NewClient(host, logger, ACL(acl)) defer c.Stop() - _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger) + _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) if err != stdzk.ErrNoAuth { t.Errorf("want %v, have %v", stdzk.ErrNoAuth, err) @@ -116,7 +116,7 @@ func TestConnection(t *testing.T) { c, _ := NewClient(host, logger) c.Stop() - _, err := NewPublisher(c, "/acl-issue-test", newFactory(""), logger) + _, err := NewSubscriber(c, "/acl-issue-test", newFactory(""), logger) if err != ErrClientClosed { t.Errorf("want %v, have %v", ErrClientClosed, err) @@ -134,7 +134,7 @@ func TestGetEntriesOnServer(t *testing.T) { defer c1.Stop() c2, err := NewClient(host, logger) - p, err := NewPublisher(c2, path, newFactory(""), logger) + s, err := NewSubscriber(c2, path, newFactory(""), logger) if err != nil { t.Fatal(err) } @@ -162,11 +162,11 @@ func TestGetEntriesOnServer(t *testing.T) { time.Sleep(50 * time.Millisecond) - endpoints, err := p.Endpoints() + services, err := s.Services() if err != nil { t.Fatal(err) } - if want, have := 2, len(endpoints); want != have { + if want, have := 2, len(services); want != have { t.Errorf("want %d, have %d", want, have) } } diff --git a/loadbalancer/zk/logwrapper.go b/sd/zk/logwrapper.go similarity index 53% rename from loadbalancer/zk/logwrapper.go rename to sd/zk/logwrapper.go index e4fd3c45d..abb7b6dfd 100644 --- a/loadbalancer/zk/logwrapper.go +++ b/sd/zk/logwrapper.go @@ -8,18 +8,18 @@ import ( "github.com/go-kit/kit/log" ) -// wrapLogger wraps a go-kit logger so we can use it as the logging service for -// the ZooKeeper library (which expects a Printf method to be available) +// wrapLogger wraps a Go kit logger so we can use it as the logging service for +// the ZooKeeper library, which expects a Printf method to be available. type wrapLogger struct { log.Logger } -func (logger wrapLogger) Printf(str string, vars ...interface{}) { - logger.Log("msg", fmt.Sprintf(str, vars...)) +func (logger wrapLogger) Printf(format string, args ...interface{}) { + logger.Log("msg", fmt.Sprintf(format, args...)) } -// withLogger replaces the ZooKeeper library's default logging service for our -// own go-kit logger +// withLogger replaces the ZooKeeper library's default logging service with our +// own Go kit logger. func withLogger(logger log.Logger) func(c *zk.Conn) { return func(c *zk.Conn) { c.SetLogger(wrapLogger{logger}) diff --git a/sd/zk/subscriber.go b/sd/zk/subscriber.go new file mode 100644 index 000000000..b9c67db43 --- /dev/null +++ b/sd/zk/subscriber.go @@ -0,0 +1,86 @@ +package zk + +import ( + "github.com/samuel/go-zookeeper/zk" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" + "github.com/go-kit/kit/sd/cache" +) + +// Subscriber yield endpoints stored in a certain ZooKeeper path. Any kind of +// change in that path is watched and will update the Subscriber endpoints. +type Subscriber struct { + client Client + path string + cache *cache.Cache + logger log.Logger + quitc chan struct{} +} + +var _ sd.Subscriber = &Subscriber{} + +// NewSubscriber returns a ZooKeeper subscriber. ZooKeeper will start watching +// the given path for changes and update the Subscriber endpoints. +func NewSubscriber(c Client, path string, factory sd.Factory, logger log.Logger) (*Subscriber, error) { + s := &Subscriber{ + client: c, + path: path, + cache: cache.New(factory, logger), + logger: logger, + quitc: make(chan struct{}), + } + + err := s.client.CreateParentNodes(s.path) + if err != nil { + return nil, err + } + + instances, eventc, err := s.client.GetEntries(s.path) + if err != nil { + logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err) + return nil, err + } + logger.Log("path", s.path, "instances", len(instances)) + s.cache.Update(instances) + + go s.loop(eventc) + + return s, nil +} + +func (s *Subscriber) loop(eventc <-chan zk.Event) { + var ( + instances []string + err error + ) + for { + select { + case <-eventc: + // We received a path update notification. Call GetEntries to + // retrieve child node data, and set a new watch, as ZK watches are + // one-time triggers. + instances, eventc, err = s.client.GetEntries(s.path) + if err != nil { + s.logger.Log("path", s.path, "msg", "failed to retrieve entries", "err", err) + continue + } + s.logger.Log("path", s.path, "instances", len(instances)) + s.cache.Update(instances) + + case <-s.quitc: + return + } + } +} + +// Endpoints implements the Subscriber interface. +func (s *Subscriber) Endpoints() ([]endpoint.Endpoint, error) { + return s.cache.Endpoints(), nil +} + +// Stop terminates the Subscriber. +func (s *Subscriber) Stop() { + close(s.quitc) +} diff --git a/sd/zk/subscriber_test.go b/sd/zk/subscriber_test.go new file mode 100644 index 000000000..79bdb84ed --- /dev/null +++ b/sd/zk/subscriber_test.go @@ -0,0 +1,117 @@ +package zk + +import ( + "testing" + "time" +) + +func TestSubscriber(t *testing.T) { + client := newFakeClient() + + s, err := NewSubscriber(client, path, newFactory(""), logger) + if err != nil { + t.Fatalf("failed to create new Subscriber: %v", err) + } + defer s.Stop() + + if _, err := s.Endpoints(); err != nil { + t.Fatal(err) + } +} + +func TestBadFactory(t *testing.T) { + client := newFakeClient() + + s, err := NewSubscriber(client, path, newFactory("kaboom"), logger) + if err != nil { + t.Fatalf("failed to create new Subscriber: %v", err) + } + defer s.Stop() + + // instance1 came online + client.AddService(path+"/instance1", "kaboom") + + // instance2 came online + client.AddService(path+"/instance2", "zookeeper_node_data") + + if err = asyncTest(100*time.Millisecond, 1, s); err != nil { + t.Error(err) + } +} + +func TestServiceUpdate(t *testing.T) { + client := newFakeClient() + + s, err := NewSubscriber(client, path, newFactory(""), logger) + if err != nil { + t.Fatalf("failed to create new Subscriber: %v", err) + } + defer s.Stop() + + endpoints, err := s.Endpoints() + if err != nil { + t.Fatal(err) + } + if want, have := 0, len(endpoints); want != have { + t.Errorf("want %d, have %d", want, have) + } + + // instance1 came online + client.AddService(path+"/instance1", "zookeeper_node_data1") + + // instance2 came online + client.AddService(path+"/instance2", "zookeeper_node_data2") + + // we should have 2 instances + if err = asyncTest(100*time.Millisecond, 2, s); err != nil { + t.Error(err) + } + + // TODO(pb): this bit is flaky + // + //// watch triggers an error... + //client.SendErrorOnWatch() + // + //// test if error was consumed + //if err = client.ErrorIsConsumedWithin(100 * time.Millisecond); err != nil { + // t.Error(err) + //} + + // instance3 came online + client.AddService(path+"/instance3", "zookeeper_node_data3") + + // we should have 3 instances + if err = asyncTest(100*time.Millisecond, 3, s); err != nil { + t.Error(err) + } + + // instance1 goes offline + client.RemoveService(path + "/instance1") + + // instance2 goes offline + client.RemoveService(path + "/instance2") + + // we should have 1 instance + if err = asyncTest(100*time.Millisecond, 1, s); err != nil { + t.Error(err) + } +} + +func TestBadSubscriberCreate(t *testing.T) { + client := newFakeClient() + client.SendErrorOnWatch() + s, err := NewSubscriber(client, path, newFactory(""), logger) + if err == nil { + t.Error("expected error on new Subscriber") + } + if s != nil { + t.Error("expected Subscriber not to be created") + } + s, err = NewSubscriber(client, "BadPath", newFactory(""), logger) + if err == nil { + t.Error("expected error on new Subscriber") + } + if s != nil { + t.Error("expected Subscriber not to be created") + } +} diff --git a/loadbalancer/zk/util_test.go b/sd/zk/util_test.go similarity index 75% rename from loadbalancer/zk/util_test.go rename to sd/zk/util_test.go index 078bcec82..2a4e1fe2f 100644 --- a/loadbalancer/zk/util_test.go +++ b/sd/zk/util_test.go @@ -11,8 +11,8 @@ import ( "golang.org/x/net/context" "github.com/go-kit/kit/endpoint" - "github.com/go-kit/kit/loadbalancer" "github.com/go-kit/kit/log" + "github.com/go-kit/kit/sd" ) var ( @@ -30,7 +30,7 @@ type fakeClient struct { func newFakeClient() *fakeClient { return &fakeClient{ - ch: make(chan zk.Event, 5), + ch: make(chan zk.Event, 1), responses: make(map[string]string), result: true, } @@ -38,7 +38,7 @@ func newFakeClient() *fakeClient { func (c *fakeClient) CreateParentNodes(path string) error { if path == "BadPath" { - return errors.New("Dummy Error") + return errors.New("dummy error") } return nil } @@ -48,7 +48,7 @@ func (c *fakeClient) GetEntries(path string) ([]string, <-chan zk.Event, error) defer c.mtx.Unlock() if c.result == false { c.result = true - return []string{}, c.ch, errors.New("Dummy Error") + return []string{}, c.ch, errors.New("dummy error") } responses := []string{} for _, data := range c.responses { @@ -78,12 +78,12 @@ func (c *fakeClient) SendErrorOnWatch() { c.ch <- zk.Event{} } -func (c *fakeClient) ErrorIsConsumed(t time.Duration) error { - timeout := time.After(t) +func (c *fakeClient) ErrorIsConsumedWithin(timeout time.Duration) error { + t := time.After(timeout) for { select { - case <-timeout: - return fmt.Errorf("expected error not consumed after timeout %s", t.String()) + case <-t: + return fmt.Errorf("expected error not consumed after timeout %s", timeout) default: c.mtx.Lock() if c.result == false { @@ -97,31 +97,30 @@ func (c *fakeClient) ErrorIsConsumed(t time.Duration) error { func (c *fakeClient) Stop() {} -func newFactory(fakeError string) loadbalancer.Factory { +func newFactory(fakeError string) sd.Factory { return func(instance string) (endpoint.Endpoint, io.Closer, error) { if fakeError == instance { return nil, nil, errors.New(fakeError) } - return e, nil, nil + return endpoint.Nop, nil, nil } } -func asyncTest(timeout time.Duration, want int, p *Publisher) (err error) { +func asyncTest(timeout time.Duration, want int, s *Subscriber) (err error) { var endpoints []endpoint.Endpoint - // want can never be -1 - have := -1 + have := -1 // want can never be <0 t := time.After(timeout) for { select { case <-t: - return fmt.Errorf("want %d, have %d after timeout %s", want, have, timeout.String()) + return fmt.Errorf("want %d, have %d (timeout %s)", want, have, timeout.String()) default: - endpoints, err = p.Endpoints() + endpoints, err = s.Endpoints() have = len(endpoints) if err != nil || want == have { return } - time.Sleep(time.Millisecond) + time.Sleep(timeout / 10) } } } diff --git a/transport/grpc/client.go b/transport/grpc/client.go index d68f89f42..b1c9af3fe 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -3,6 +3,7 @@ package grpc import ( "fmt" "reflect" + "strings" "golang.org/x/net/context" "google.golang.org/grpc" @@ -24,7 +25,7 @@ type Client struct { } // NewClient constructs a usable Client for a single remote endpoint. -// Pass an zero-value Protobuf message of the RPC response type as +// Pass an zero-value protobuf message of the RPC response type as // the grpcReply argument. func NewClient( cc *grpc.ClientConn, @@ -35,9 +36,12 @@ func NewClient( grpcReply interface{}, options ...ClientOption, ) *Client { + if strings.IndexByte(serviceName, '.') == -1 { + serviceName = "pb." + serviceName + } c := &Client{ client: cc, - method: fmt.Sprintf("/pb.%s/%s", serviceName, method), + method: fmt.Sprintf("/%s/%s", serviceName, method), enc: enc, dec: dec, // We are using reflect.Indirect here to allow both reply structs and diff --git a/transport/grpc/server.go b/transport/grpc/server.go index cbbb0772b..acb542a92 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -9,9 +9,10 @@ import ( ) // Handler which should be called from the grpc binding of the service -// implementation. +// implementation. The incoming request parameter, and returned response +// parameter, are both gRPC types, not user-domain. type Handler interface { - ServeGRPC(context.Context, interface{}) (context.Context, interface{}, error) + ServeGRPC(ctx context.Context, request interface{}) (context.Context, interface{}, error) } // Server wraps an endpoint and implements grpc.Handler. @@ -25,8 +26,11 @@ type Server struct { logger log.Logger } -// NewServer constructs a new server, which implements grpc.Server and wraps -// the provided endpoint. +// NewServer constructs a new server, which implements wraps the provided +// endpoint and implements the Handler interface. Consumers should write +// bindings that adapt the concrete gRPC methods from their compiled protobuf +// definitions to individual handlers. Request and response objects are from the +// caller business domain, not gRPC request and reply types. func NewServer( ctx context.Context, e endpoint.Endpoint, @@ -68,12 +72,12 @@ func ServerErrorLogger(logger log.Logger) ServerOption { return func(s *Server) { s.logger = logger } } -// ServeGRPC implements grpc.Handler -func (s Server) ServeGRPC(grpcCtx context.Context, r interface{}) (context.Context, interface{}, error) { +// ServeGRPC implements the Handler interface. +func (s Server) ServeGRPC(grpcCtx context.Context, req interface{}) (context.Context, interface{}, error) { ctx, cancel := context.WithCancel(s.ctx) defer cancel() - // retrieve gRPC metadata + // Retrieve gRPC metadata. md, ok := metadata.FromContext(grpcCtx) if !ok { md = metadata.MD{} @@ -83,10 +87,10 @@ func (s Server) ServeGRPC(grpcCtx context.Context, r interface{}) (context.Conte ctx = f(ctx, &md) } - // store potentially updated metadata in the gRPC context + // Store potentially updated metadata in the gRPC context. grpcCtx = metadata.NewContext(grpcCtx, md) - request, err := s.dec(grpcCtx, r) + request, err := s.dec(grpcCtx, req) if err != nil { s.logger.Log("err", err) return grpcCtx, nil, BadRequestError{err} @@ -102,7 +106,7 @@ func (s Server) ServeGRPC(grpcCtx context.Context, r interface{}) (context.Conte f(ctx, &md) } - // store potentially updated metadata in the gRPC context + // Store potentially updated metadata in the gRPC context. grpcCtx = metadata.NewContext(grpcCtx, md) grpcResp, err := s.enc(grpcCtx, response) @@ -110,6 +114,7 @@ func (s Server) ServeGRPC(grpcCtx context.Context, r interface{}) (context.Conte s.logger.Log("err", err) return grpcCtx, nil, err } + return grpcCtx, grpcResp, nil } diff --git a/transport/http/client.go b/transport/http/client.go index 39475c946..66523193c 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -21,7 +21,7 @@ type Client struct { bufferedStream bool } -// NewClient constructs a usable Client for a single remote endpoint. +// NewClient constructs a usable Client for a single remote method. func NewClient( method string, tgt *url.URL, @@ -65,8 +65,7 @@ func SetBufferedStream(buffered bool) ClientOption { return func(c *Client) { c.bufferedStream = buffered } } -// Endpoint returns a usable endpoint that will invoke the RPC specified by -// the client. +// Endpoint returns a usable endpoint that invokes the remote endpoint. func (c Client) Endpoint() endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { ctx, cancel := context.WithCancel(ctx)