Skip to content

Commit 96b9868

Browse files
committed
http2: avoid empty ALPN on TLS connections
1 parent 9cd9cbd commit 96b9868

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

helper/http2/http2.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ func (srv *Server) serveConn(conn net.Conn, baseCtx context.Context) error {
128128
var proto string
129129
switch conn := conn.(type) {
130130
case *tls.Conn:
131+
if err := conn.Handshake(); err != nil {
132+
conn.Close()
133+
return err
134+
}
131135
proto = conn.ConnectionState().NegotiatedProtocol
132136
case *proxyproto.Conn:
133137
if proxyHeader := conn.ProxyHeader(); proxyHeader != nil {

helper/http2/http2_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@ package http2_test
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"crypto/rsa"
7+
"crypto/tls"
8+
"crypto/x509"
9+
"crypto/x509/pkix"
510
"errors"
611
"log"
12+
"math/big"
713
"net"
814
"net/http"
915
"testing"
16+
"time"
1017

1118
"github.com/pires/go-proxyproto"
1219
h2proxy "github.com/pires/go-proxyproto/helper/http2"
@@ -94,6 +101,36 @@ func TestServer_h2(t *testing.T) {
94101
resp.Body.Close()
95102
}
96103

104+
func TestServer_h2_tls(t *testing.T) {
105+
addr, server := newTLSTestServer(t)
106+
defer server.Close()
107+
108+
conn, err := tls.Dial("tcp", addr, &tls.Config{
109+
InsecureSkipVerify: true,
110+
NextProtos: []string{http2.NextProtoTLS},
111+
})
112+
if err != nil {
113+
t.Fatalf("failed to dial: %v", err)
114+
}
115+
defer conn.Close()
116+
117+
h2Conn, err := new(http2.Transport).NewClientConn(conn)
118+
if err != nil {
119+
t.Fatalf("failed to create HTTP connection: %v", err)
120+
}
121+
122+
req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
123+
if err != nil {
124+
t.Fatalf("failed to create HTTP request: %v", err)
125+
}
126+
127+
resp, err := h2Conn.RoundTrip(req)
128+
if err != nil {
129+
t.Fatalf("failed to perform HTTP request: %v", err)
130+
}
131+
resp.Body.Close()
132+
}
133+
97134
func newTestServer(t *testing.T) (addr string, server *http.Server) {
98135
ln, err := net.Listen("tcp", "localhost:0")
99136
if err != nil {
@@ -132,3 +169,86 @@ func newTestServer(t *testing.T) (addr string, server *http.Server) {
132169

133170
return ln.Addr().String(), server
134171
}
172+
173+
func newTLSTestServer(t *testing.T) (addr string, server *http.Server) {
174+
ln, err := net.Listen("tcp", "localhost:0")
175+
if err != nil {
176+
t.Fatalf("failed to listen: %v", err)
177+
}
178+
179+
server = &http.Server{
180+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
181+
if v := r.Context().Value(connContextKey); v == nil {
182+
t.Errorf("http.Request.Context missing connContextKey")
183+
}
184+
if v := r.Context().Value(baseContextKey); v == nil {
185+
t.Errorf("http.Request.Context missing baseContextKey")
186+
}
187+
}),
188+
BaseContext: func(_ net.Listener) context.Context {
189+
return context.WithValue(context.Background(), baseContextKey, struct{}{})
190+
},
191+
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
192+
return context.WithValue(ctx, connContextKey, struct{}{})
193+
},
194+
}
195+
196+
tlsLn := tls.NewListener(ln, testTLSConfig(t))
197+
h2Server := h2proxy.NewServer(server, nil)
198+
done := make(chan error, 1)
199+
go func() {
200+
done <- h2Server.Serve(tlsLn)
201+
}()
202+
203+
t.Cleanup(func() {
204+
err := <-done
205+
if err != nil && !errors.Is(err, net.ErrClosed) {
206+
t.Fatalf("failed to serve: %v", err)
207+
}
208+
})
209+
210+
return ln.Addr().String(), server
211+
}
212+
213+
func testTLSConfig(t *testing.T) *tls.Config {
214+
t.Helper()
215+
216+
key, err := rsa.GenerateKey(rand.Reader, 2048)
217+
if err != nil {
218+
t.Fatalf("failed to generate key: %v", err)
219+
}
220+
221+
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
222+
if err != nil {
223+
t.Fatalf("failed to generate serial: %v", err)
224+
}
225+
226+
template := x509.Certificate{
227+
SerialNumber: serial,
228+
Subject: pkix.Name{
229+
CommonName: "localhost",
230+
},
231+
NotBefore: time.Now().Add(-time.Hour),
232+
NotAfter: time.Now().Add(time.Hour),
233+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
234+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
235+
BasicConstraintsValid: true,
236+
DNSNames: []string{"localhost"},
237+
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
238+
}
239+
240+
der, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
241+
if err != nil {
242+
t.Fatalf("failed to create cert: %v", err)
243+
}
244+
245+
cert := tls.Certificate{
246+
Certificate: [][]byte{der},
247+
PrivateKey: key,
248+
}
249+
250+
return &tls.Config{
251+
Certificates: []tls.Certificate{cert},
252+
NextProtos: []string{http2.NextProtoTLS},
253+
}
254+
}

0 commit comments

Comments
 (0)