diff --git a/helper/http2/http2.go b/helper/http2/http2.go index 24b59b0..dd7788a 100644 --- a/helper/http2/http2.go +++ b/helper/http2/http2.go @@ -2,6 +2,7 @@ package http2 import ( + "context" "crypto/tls" "fmt" "log" @@ -143,7 +144,13 @@ func (srv *Server) serveConn(conn net.Conn) error { switch proto { case http2.NextProtoTLS, "h2c": defer conn.Close() - opts := http2.ServeConnOpts{Handler: srv.h1.Handler} + + ctx := context.Background() + if srv.h1.ConnContext != nil { + ctx = srv.h1.ConnContext(ctx, conn) + } + + opts := http2.ServeConnOpts{Context: ctx, BaseConfig: srv.h1} srv.h2.ServeConn(conn, &opts) return nil case "", "http/1.0", "http/1.1": diff --git a/helper/http2/http2_test.go b/helper/http2/http2_test.go index 054f12d..aa8719c 100644 --- a/helper/http2/http2_test.go +++ b/helper/http2/http2_test.go @@ -1,6 +1,7 @@ package http2_test import ( + "context" "errors" "log" "net" @@ -32,6 +33,10 @@ func ExampleServer() { } } +type contextKey string + +const connContextKey = contextKey("conn") + func TestServer_h1(t *testing.T) { addr, server := newTestServer(t) defer server.Close() @@ -94,7 +99,13 @@ func newTestServer(t *testing.T) (addr string, server *http.Server) { server = &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if v := r.Context().Value(connContextKey); v == nil { + t.Errorf("http.Request.Context missing connContextKey") + } }), + ConnContext: func(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, connContextKey, struct{}{}) + }, } h2Server := h2proxy.NewServer(server, nil)