Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ type ClientConn struct {
var _ http.RoundTripper = &ClientConn{}

func newClientConn(
conn *quic.Conn,
conn QUICConn,
enableDatagrams bool,
additionalSettings map[uint64]uint64,
streamHijacker func(FrameType, quic.ConnectionTracingID, *quic.Stream, error) (hijacked bool, err error),
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool),
streamHijacker func(FrameType, quic.ConnectionTracingID, QUICStream, error) (hijacked bool, err error),
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, QUICReceiveStream, error) (hijacked bool),
maxResponseHeaderBytes int64,
disableCompression bool,
logger *slog.Logger,
Expand Down Expand Up @@ -141,7 +141,7 @@ func (c *ClientConn) setupConn() error {
return err
}

func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, *quic.Stream, error) (hijacked bool, err error)) {
func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, QUICStream, error) (hijacked bool, err error)) {
for {
str, err := c.conn.conn.AcceptStream(context.Background())
if err != nil {
Expand Down
18 changes: 9 additions & 9 deletions http3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var errGoAway = errors.New("connection in graceful shutdown")
// It has all methods from the quic.Conn expect for AcceptStream, AcceptUniStream,
// SendDatagram and ReceiveDatagram.
type Conn struct {
conn *quic.Conn
conn QUICConn

ctx context.Context

Expand All @@ -53,7 +53,7 @@ type Conn struct {

func newConnection(
ctx context.Context,
quicConn *quic.Conn,
quicConn QUICConn,
enableDatagrams bool,
perspective protocol.Perspective,
logger *slog.Logger,
Expand All @@ -78,19 +78,19 @@ func newConnection(
return c
}

func (c *Conn) OpenStream() (*quic.Stream, error) {
func (c *Conn) OpenStream() (QUICStream, error) {
return c.conn.OpenStream()
}

func (c *Conn) OpenStreamSync(ctx context.Context) (*quic.Stream, error) {
func (c *Conn) OpenStreamSync(ctx context.Context) (QUICStream, error) {
return c.conn.OpenStreamSync(ctx)
}

func (c *Conn) OpenUniStream() (*quic.SendStream, error) {
func (c *Conn) OpenUniStream() (QUICSendStream, error) {
return c.conn.OpenUniStream()
}

func (c *Conn) OpenUniStreamSync(ctx context.Context) (*quic.SendStream, error) {
func (c *Conn) OpenUniStreamSync(ctx context.Context) (QUICSendStream, error) {
return c.conn.OpenUniStreamSync(ctx)
}

Expand Down Expand Up @@ -224,7 +224,7 @@ func (c *Conn) CloseWithError(code quic.ApplicationErrorCode, msg string) error
return c.conn.CloseWithError(code, msg)
}

func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool)) {
func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, QUICReceiveStream, error) (hijacked bool)) {
var (
rcvdControlStr atomic.Bool
rcvdQPACKEncoderStr atomic.Bool
Expand All @@ -240,7 +240,7 @@ func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.Connecti
return
}

go func(str *quic.ReceiveStream) {
go func(str QUICReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
id := c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
Expand Down Expand Up @@ -301,7 +301,7 @@ func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.Connecti
}
}

func (c *Conn) handleControlStream(str *quic.ReceiveStream) {
func (c *Conn) handleControlStream(str QUICReceiveStream) {
fp := &frameParser{closeConn: c.conn.CloseWithError, r: str}
f, err := fp.ParseNext()
if err != nil {
Expand Down
102 changes: 102 additions & 0 deletions http3/quic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package http3

import (
"context"
"io"
"net"
"time"

"github.com/quic-go/quic-go"
)

type QUICSendStream interface {
StreamID() quic.StreamID
io.WriteCloser
CancelWrite(quic.StreamErrorCode)
SetWriteDeadline(time.Time) error
}

var _ QUICSendStream = &quic.SendStream{}

type QUICReceiveStream interface {
StreamID() quic.StreamID
io.Reader
CancelRead(quic.StreamErrorCode)
SetReadDeadline(time.Time) error
}

var _ QUICReceiveStream = &quic.ReceiveStream{}

type QUICStream interface {
QUICSendStream
QUICReceiveStream
Context() context.Context
SetDeadline(time.Time) error
}

var _ QUICStream = &quic.Stream{}

type QUICConn interface {
OpenStream() (QUICStream, error)
OpenStreamSync(context.Context) (QUICStream, error)
OpenUniStream() (QUICSendStream, error)
OpenUniStreamSync(context.Context) (QUICSendStream, error)
AcceptStream(context.Context) (QUICStream, error)
AcceptUniStream(context.Context) (QUICReceiveStream, error)

Context() context.Context
LocalAddr() net.Addr
RemoteAddr() net.Addr
CloseWithError(quic.ApplicationErrorCode, string) error
ConnectionState() quic.ConnectionState
HandshakeComplete() <-chan struct{}
SendDatagram([]byte) error
ReceiveDatagram(context.Context) ([]byte, error)
}

type connAdapter struct {
*quic.Conn
}

func (c *connAdapter) OpenStream() (QUICStream, error) {
return c.OpenStream()
}

func (c *connAdapter) OpenStreamSync(ctx context.Context) (QUICStream, error) {
return c.OpenStreamSync(ctx)
}

func (c *connAdapter) OpenUniStream() (QUICSendStream, error) {
return c.OpenUniStream()
}

func (c *connAdapter) OpenUniStreamSync(ctx context.Context) (QUICSendStream, error) {
return c.OpenUniStreamSync(ctx)
}

func (c *connAdapter) AcceptStream(ctx context.Context) (QUICStream, error) {
return c.AcceptStream(ctx)
}

func (c *connAdapter) AcceptUniStream(ctx context.Context) (QUICReceiveStream, error) {
return c.AcceptUniStream(ctx)
}

// A QUICListener listens for incoming QUIC connections.
type QUICListener interface {
Accept(context.Context) (QUICConn, error)
Addr() net.Addr
io.Closer
}

type quicListenerAdapter struct {
*quic.EarlyListener
}

func (l *quicListenerAdapter) Accept(ctx context.Context) (QUICConn, error) {
conn, err := l.EarlyListener.Accept(ctx)
if err != nil {
return nil, err
}
return &connAdapter{conn}, nil
}
58 changes: 31 additions & 27 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,6 @@ const (
streamTypeQPACKDecoderStream = 3
)

// A QUICListener listens for incoming QUIC connections.
type QUICListener interface {
Accept(context.Context) (*quic.Conn, error)
Addr() net.Addr
io.Closer
}

var _ QUICListener = &quic.EarlyListener{}

// ConfigureTLSConfig creates a new tls.Config which can be used
// to create a quic.Listener meant for serving HTTP/3.
func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
Expand Down Expand Up @@ -156,12 +147,12 @@ type Server struct {
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.ConnectionTracingID, *quic.Stream, error) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.ConnectionTracingID, QUICStream, error) (hijacked bool, err error)

// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, QUICReceiveStream, error) (hijacked bool)

// IdleTimeout specifies how long until idle clients connection should be
// closed. Idle refers only to the HTTP/3 layer, activity at the QUIC layer
Expand All @@ -171,7 +162,7 @@ type Server struct {

// ConnContext optionally specifies a function that modifies the context used for a new connection c.
// The provided ctx has a ServerContextKey value.
ConnContext func(ctx context.Context, c *quic.Conn) context.Context
ConnContext func(ctx context.Context, c QUICConn) context.Context

Logger *slog.Logger

Expand Down Expand Up @@ -252,8 +243,7 @@ func (s *Server) decreaseConnCount() {
}
}

// ServeQUICConn serves a single QUIC connection.
func (s *Server) ServeQUICConn(conn *quic.Conn) error {
func (s *Server) ServeQUICConnInterface(conn QUICConn) error {
s.mutex.Lock()
if s.closed {
s.mutex.Unlock()
Expand All @@ -269,12 +259,12 @@ func (s *Server) ServeQUICConn(conn *quic.Conn) error {
return s.handleConn(conn)
}

// ServeListener serves an existing QUIC listener.
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener.
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
func (s *Server) ServeListener(ln QUICListener) error {
// ServeQUICConn serves a single QUIC connection.
func (s *Server) ServeQUICConn(conn *quic.Conn) error {
return s.ServeQUICConnInterface(&connAdapter{conn})
}

func (s *Server) ServeListenerInterface(ln QUICListener) error {
s.mutex.Lock()
if err := s.addListener(&ln, false); err != nil {
s.mutex.Unlock()
Expand All @@ -286,6 +276,15 @@ func (s *Server) ServeListener(ln QUICListener) error {
return s.serveListener(ln)
}

// ServeListener serves an existing QUIC listener.
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener.
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
func (s *Server) ServeListener(ln *quic.EarlyListener) error {
return s.ServeListenerInterface(&quicListenerAdapter{ln})
}

func (s *Server) serveListener(ln QUICListener) error {
for {
conn, err := ln.Accept(s.graceCtx)
Expand Down Expand Up @@ -334,19 +333,24 @@ func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn)
}

var ln QUICListener
var err error
if conn == nil {
addr := s.Addr
if addr == "" {
addr = ":https"
}
ln, err = quic.ListenAddrEarly(addr, baseConf, quicConf)
l, err := quic.ListenAddrEarly(addr, baseConf, quicConf)
if err != nil {
return nil, err
}
ln = &quicListenerAdapter{l}
} else {
ln, err = quic.ListenEarly(conn, baseConf, quicConf)
}
if err != nil {
return nil, err
l, err := quic.ListenEarly(conn, baseConf, quicConf)
if err != nil {
return nil, err
}
ln = &quicListenerAdapter{l}
}

if err := s.addListener(&ln, true); err != nil {
return nil, err
}
Expand Down Expand Up @@ -437,7 +441,7 @@ func (s *Server) removeListener(l *QUICListener) {

// handleConn handles the HTTP/3 exchange on a QUIC connection.
// It blocks until all HTTP handlers for all streams have returned.
func (s *Server) handleConn(conn *quic.Conn) error {
func (s *Server) handleConn(conn QUICConn) error {
// open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later
// when the server is gracefully closed
ctrlStr, err := conn.OpenUniStream()
Expand Down
4 changes: 2 additions & 2 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,12 +490,12 @@ func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHij
hijackChan := make(chan hijackCall, 1)
testDone := make(chan struct{})
s := &Server{
StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ *quic.Stream, e error) (hijacked bool, err error) {
StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ QUICStream, e error) (hijacked bool, err error) {
defer close(testDone)
hijackChan <- hijackCall{ft: ft, connTracingID: connTracingID, e: e}
return doHijack, hijackErr
},
UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ *quic.ReceiveStream, err error) bool {
UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ QUICReceiveStream, err error) bool {
defer close(testDone)
hijackChan <- hijackCall{st: st, connTracingID: connTracingID, e: err}
return doHijack
Expand Down
27 changes: 24 additions & 3 deletions http3/state_tracking_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"os"
"sync"
"time"

"github.com/quic-go/quic-go"
)
Expand All @@ -19,7 +20,7 @@ const streamDatagramQueueLen = 32
// parent connection, this is done through the streamClearer interface when
// both the send and receive sides are closed
type stateTrackingStream struct {
*quic.Stream
Stream QUICStream

sendDatagram func([]byte) error
hasData chan struct{}
Expand All @@ -38,7 +39,7 @@ type streamClearer interface {
clearStream(quic.StreamID)
}

func newStateTrackingStream(s *quic.Stream, clearer streamClearer, sendDatagram func([]byte) error) *stateTrackingStream {
func newStateTrackingStream(s QUICStream, clearer streamClearer, sendDatagram func([]byte) error) *stateTrackingStream {
t := &stateTrackingStream{
Stream: s,
clearer: clearer,
Expand Down Expand Up @@ -168,6 +169,26 @@ start:
goto start
}

func (s *stateTrackingStream) QUICStream() *quic.Stream {
func (s *stateTrackingStream) QUICStream() QUICStream {
return s.Stream
}

func (s *stateTrackingStream) Context() context.Context {
return s.Stream.Context()
}

func (s *stateTrackingStream) StreamID() quic.StreamID {
return s.Stream.StreamID()
}

func (s *stateTrackingStream) SetDeadline(t time.Time) error {
return s.Stream.SetDeadline(t)
}

func (s *stateTrackingStream) SetReadDeadline(t time.Time) error {
return s.Stream.SetReadDeadline(t)
}

func (s *stateTrackingStream) SetWriteDeadline(t time.Time) error {
return s.Stream.SetWriteDeadline(t)
}
2 changes: 1 addition & 1 deletion http3/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type datagramStream interface {
SendDatagram(b []byte) error
ReceiveDatagram(ctx context.Context) ([]byte, error)

QUICStream() *quic.Stream
QUICStream() QUICStream
}

// A Stream is an HTTP/3 stream.
Expand Down
Loading
Loading