Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,33 @@ var (
// Only one of Policy or ConnPolicy should be provided. If both are provided then
// a panic would occur during accept.
type Listener struct {
// Listener is the underlying listener.
Listener net.Listener
// Deprecated: use ConnPolicyFunc instead. This will be removed in future release.
Policy PolicyFunc
ConnPolicy ConnPolicyFunc
ValidateHeader Validator
Policy PolicyFunc
// ConnPolicy is the policy function for accepted connections.
ConnPolicy ConnPolicyFunc
// ValidateHeader is the validator function for the proxy header.
ValidateHeader Validator
// ReadHeaderTimeout is the timeout for reading the proxy header.
ReadHeaderTimeout time.Duration
// ReadBufferSize is the read buffer size for accepted connections. When > 0,
// each accepted connection uses this size for proxy header detection; 0 means default.
ReadBufferSize int
}

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address. Each connection
// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
type Conn struct {
readDeadline atomic.Value // time.Time
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
readDeadline atomic.Value // time.Time
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
// bufferSize is set when the client overrides via WithBufferSize; nil means use default.
bufferSize *int
header *Header
ProxyHeaderPolicy Policy
Validate Validator
Expand Down Expand Up @@ -89,6 +98,22 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
}
}

// WithBufferSize sets the size of the read buffer used for proxy header detection.
// Values <= 0 are ignored and the default (256 bytes) is used. Values < 16 are
// effectively 16 due to bufio's minimum. The default is tuned for typical proxy
// protocol header lengths.
func WithBufferSize(length int) func(*Conn) {
return func(c *Conn) {
if length <= 0 {
return
}
p := new(int)
*p = length
c.bufferSize = p
c.bufReader = bufio.NewReaderSize(c.conn, length)
}
}

// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
for {
Expand Down Expand Up @@ -130,11 +155,14 @@ func (p *Listener) Accept() (net.Conn, error) {
}
}

newConn := NewConn(
conn,
opts := []func(*Conn){
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)
}
if p.ReadBufferSize > 0 {
opts = append(opts, WithBufferSize(p.ReadBufferSize))
}
newConn := NewConn(conn, opts...)

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
Expand Down
110 changes: 110 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,116 @@ func TestNewConnSetReadHeaderTimeoutIgnoresNegative(t *testing.T) {
}
}

func TestWithBufferSizePositive(t *testing.T) {
conn, peer := net.Pipe()
t.Cleanup(func() {
_ = conn.Close()
_ = peer.Close()
})

proxyConn := NewConn(conn, WithBufferSize(4096))
if proxyConn.bufferSize == nil {
t.Fatalf("expected bufferSize to be set")
}
if *proxyConn.bufferSize != 4096 {
t.Fatalf("expected bufferSize 4096, got %d", *proxyConn.bufferSize)
}

go func() { _, _ = peer.Write([]byte("x")) }()
buf := make([]byte, 1)
if _, err := proxyConn.Read(buf); err != nil {
t.Fatalf("read failed: %v", err)
}
if string(buf) != "x" {
t.Fatalf("unexpected read: %q", buf)
}
}

func TestWithBufferSizeZeroOrNegative(t *testing.T) {
for _, length := range []int{0, -1} {
t.Run(fmt.Sprint(length), func(t *testing.T) {
conn, peer := net.Pipe()
t.Cleanup(func() {
_ = conn.Close()
_ = peer.Close()
})

proxyConn := NewConn(conn, WithBufferSize(length))
if proxyConn.bufferSize != nil {
t.Fatalf("expected bufferSize to be nil for length %d", length)
}

go func() { _, _ = peer.Write([]byte("y")) }()
buf := make([]byte, 1)
if _, err := proxyConn.Read(buf); err != nil {
t.Fatalf("read failed: %v", err)
}
if string(buf) != "y" {
t.Fatalf("unexpected read: %q", buf)
}
})
}
}

func TestListenerReadBufferSizeApplied(t *testing.T) {
l, err := net.Listen("tcp", testLocalhostRandomPort)
if err != nil {
t.Fatalf("err: %v", err)
}
t.Cleanup(func() { _ = l.Close() })

pl := &Listener{Listener: l, ReadBufferSize: 4096}

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_ = c.Close()
}
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })

proxyConn := conn.(*Conn)
if proxyConn.bufferSize == nil {
t.Fatalf("expected bufferSize to be set when Listener.ReadBufferSize > 0")
}
if *proxyConn.bufferSize != 4096 {
t.Fatalf("expected bufferSize 4096, got %d", *proxyConn.bufferSize)
}
}

func TestListenerReadBufferSizeZeroUsesDefault(t *testing.T) {
l, err := net.Listen("tcp", testLocalhostRandomPort)
if err != nil {
t.Fatalf("err: %v", err)
}
t.Cleanup(func() { _ = l.Close() })

pl := &Listener{Listener: l, ReadBufferSize: 0}

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_ = c.Close()
}
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })

proxyConn := conn.(*Conn)
if proxyConn.bufferSize != nil {
t.Fatalf("expected bufferSize to be nil when Listener.ReadBufferSize is 0")
}
}

func TestReadHeaderTimeoutRespectsEarlierDeadline(t *testing.T) {
const (
headerTimeout = 200 * time.Millisecond
Expand Down
Loading