Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions helper/http2/http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package http2

import (
"context"
"crypto/tls"
"fmt"
"log"
Expand Down Expand Up @@ -110,18 +111,27 @@ func (srv *Server) Serve(ln net.Listener) error {

delay = 0

go func() {
if err := srv.serveConn(conn); err != nil {
baseCtx := context.Background()
if srv.h1.BaseContext != nil {
baseCtx = srv.h1.BaseContext(ln)
}

go func(conn net.Conn, baseCtx context.Context) {
if err := srv.serveConn(conn, baseCtx); err != nil {
srv.errorLog().Printf("listener %q: %v", ln.Addr(), err)
}
}()
}(conn, baseCtx)
}
}

func (srv *Server) serveConn(conn net.Conn) error {
func (srv *Server) serveConn(conn net.Conn, baseCtx context.Context) error {
var proto string
switch conn := conn.(type) {
case *tls.Conn:
if err := conn.Handshake(); err != nil {
conn.Close()
return err
}
proto = conn.ConnectionState().NegotiatedProtocol
case *proxyproto.Conn:
if proxyHeader := conn.ProxyHeader(); proxyHeader != nil {
Expand All @@ -143,7 +153,16 @@ func (srv *Server) serveConn(conn net.Conn) error {
switch proto {
case http2.NextProtoTLS, "h2c":
defer conn.Close()
opts := http2.ServeConnOpts{Handler: srv.h1.Handler}

ctx := baseCtx
// We don't check if srv.h1.ConnContext is nil so http.Server works the same
// with or without this middleware.
// For more info, see https://github.com/pires/go-proxyproto/pull/140/changes#r2725568706.
if connCtx := srv.h1.ConnContext(ctx, conn); connCtx != nil {
ctx = connCtx
}

opts := http2.ServeConnOpts{Context: ctx, BaseConfig: srv.h1}
srv.h2.ServeConn(conn, &opts)
return nil
case "", "http/1.0", "http/1.1":
Expand Down
140 changes: 140 additions & 0 deletions helper/http2/http2_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
package http2_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"log"
"math/big"
"net"
"net/http"
"testing"
"time"

"github.com/pires/go-proxyproto"
h2proxy "github.com/pires/go-proxyproto/helper/http2"
Expand All @@ -32,6 +40,13 @@ func ExampleServer() {
}
}

type contextKey string

const (
connContextKey = contextKey("conn")
baseContextKey = contextKey("base")
)

func TestServer_h1(t *testing.T) {
addr, server := newTestServer(t)
defer server.Close()
Expand Down Expand Up @@ -86,6 +101,36 @@ func TestServer_h2(t *testing.T) {
resp.Body.Close()
}

func TestServer_h2_tls(t *testing.T) {
addr, server := newTLSTestServer(t)
defer server.Close()

conn, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{http2.NextProtoTLS},
})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer conn.Close()

h2Conn, err := new(http2.Transport).NewClientConn(conn)
if err != nil {
t.Fatalf("failed to create HTTP connection: %v", err)
}

req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
if err != nil {
t.Fatalf("failed to create HTTP request: %v", err)
}

resp, err := h2Conn.RoundTrip(req)
if err != nil {
t.Fatalf("failed to perform HTTP request: %v", err)
}
resp.Body.Close()
}

func newTestServer(t *testing.T) (addr string, server *http.Server) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
Expand All @@ -94,7 +139,19 @@ func newTestServer(t *testing.T) (addr string, server *http.Server) {

server = &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if v := r.Context().Value(connContextKey); v == nil {
t.Errorf("http.Request.Context missing connContextKey")
}
if v := r.Context().Value(baseContextKey); v == nil {
t.Errorf("http.Request.Context missing baseContextKey")
}
}),
BaseContext: func(_ net.Listener) context.Context {
return context.WithValue(context.Background(), baseContextKey, struct{}{})
},
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connContextKey, struct{}{})
},
}

h2Server := h2proxy.NewServer(server, nil)
Expand All @@ -112,3 +169,86 @@ func newTestServer(t *testing.T) (addr string, server *http.Server) {

return ln.Addr().String(), server
}

func newTLSTestServer(t *testing.T) (addr string, server *http.Server) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}

server = &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if v := r.Context().Value(connContextKey); v == nil {
t.Errorf("http.Request.Context missing connContextKey")
}
if v := r.Context().Value(baseContextKey); v == nil {
t.Errorf("http.Request.Context missing baseContextKey")
}
}),
BaseContext: func(_ net.Listener) context.Context {
return context.WithValue(context.Background(), baseContextKey, struct{}{})
},
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connContextKey, struct{}{})
},
}

tlsLn := tls.NewListener(ln, testTLSConfig(t))
h2Server := h2proxy.NewServer(server, nil)
done := make(chan error, 1)
go func() {
done <- h2Server.Serve(tlsLn)
}()

t.Cleanup(func() {
err := <-done
if err != nil && !errors.Is(err, net.ErrClosed) {
t.Fatalf("failed to serve: %v", err)
}
})

return ln.Addr().String(), server
}

func testTLSConfig(t *testing.T) *tls.Config {
t.Helper()

key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}

serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
if err != nil {
t.Fatalf("failed to generate serial: %v", err)
}

template := x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: "localhost",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}

der, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
t.Fatalf("failed to create cert: %v", err)
}

cert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: key,
}

return &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{http2.NextProtoTLS},
}
}
Loading