Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package proxyproto

import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -155,7 +157,8 @@ func (p *Listener) Addr() net.Addr {
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
// We start small to save memory on idle connections, but readHeader()
// will grow the buffer if a large v2 header is detected.
const bufSize = 256
br := bufio.NewReaderSize(conn, bufSize)

Expand Down Expand Up @@ -300,6 +303,24 @@ func (p *Conn) readHeader() error {
}
}

// Peek only the fixed 16-byte v2 preamble to decide whether we should
// expand the buffer for a large header. This is a best-effort check that
// does not consume bytes; it avoids allocations unless the header looks
// well-formed enough to trust its length field.
// We need 16 bytes: 12 (Signature) + 1 (Version/Command) + 1 (Transport) + 2 (Length).
if p.bufReader != nil {
peeked, err := p.bufReader.Peek(16)
if err == nil {
if totalSize, ok := v2PreflightHeaderSize(peeked); ok {
// If the header is larger than our current buffer, wrap the existing
// reader in a larger one. This preserves any bytes already buffered.
if totalSize > p.bufReader.Size() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can result in a DoS if the client sends 65535 as the length.

p.bufReader = bufio.NewReaderSize(p.bufReader, totalSize)
Copy link
Contributor

@emersion emersion Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Manually resizing the buffer seems more complicated than necessary: I don't think we need to do that. See #155 for a simpler version.

This also results in paying the memory cost for the TLVs twice: once for the internal bufio.Reader buffer, and another time when reading TLVs as a byte slice.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More generally, bufio.Reader's internal buffer size is not supposed to be client-controlled. It's supposed to act as a way to avoid large allocations.

}
}
}
}

header, err := Read(p.bufReader)

// If the connection's readHeaderTimeout is more than 0, undo the change to the
Expand Down Expand Up @@ -352,6 +373,36 @@ func (p *Conn) readHeader() error {
return err
}

// v2PreflightHeaderSize validates the v2 preamble (signature + command +
// transport + minimum length) and returns the total header size.
// It is intentionally conservative: if anything looks suspicious, it returns false
// so that we do not resize based on an attacker-controlled length field.
func v2PreflightHeaderSize(peeked []byte) (int, bool) {
if len(peeked) < 16 {
return 0, false
}
if !bytes.Equal(peeked[:12], SIGV2) {
return 0, false
}

command := ProtocolVersionAndCommand(peeked[12])
if _, ok := supportedCommand[command]; !ok {
return 0, false
}

transport := AddressFamilyAndProtocol(peeked[13])
if transport == UNSPEC && command != LOCAL {
return 0, false
}

length := binary.BigEndian.Uint16(peeked[14:16])
if !(&Header{TransportProtocol: transport}).validateLength(length) {
return 0, false
}

return 16 + int(length), true
}

// ReadFrom implements the io.ReaderFrom ReadFrom method.
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
if rf, ok := p.conn.(io.ReaderFrom); ok {
Expand Down
51 changes: 51 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -1984,6 +1985,56 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
}
}

func TestConnReadHeaderResizesForLargeV2(t *testing.T) {
const payloadLength = 400
length := uint16(payloadLength)

payload := make([]byte, payloadLength)
copy(payload[0:4], net.ParseIP(testSourceIPv4Addr).To4())
copy(payload[4:8], net.ParseIP(testDestinationIPv4Addr).To4())
binary.BigEndian.PutUint16(payload[8:10], 1000)
binary.BigEndian.PutUint16(payload[10:12], 2000)
for i := 12; i < len(payload); i++ {
payload[i] = 0x99
}

header := make([]byte, 16)
copy(header[:12], SIGV2)
header[12] = byte(PROXY)
header[13] = byte(TCPv4)
binary.BigEndian.PutUint16(header[14:16], length)
fullData := append(header, payload...)

serverConn, clientConn := net.Pipe()
t.Cleanup(func() {
if closeErr := serverConn.Close(); closeErr != nil {
t.Errorf("failed to close server connection: %v", closeErr)
}
if closeErr := clientConn.Close(); closeErr != nil {
t.Errorf("failed to close client connection: %v", closeErr)
}
})

go func() {
_, _ = clientConn.Write(fullData)
_ = clientConn.Close()
}()

conn := NewConn(serverConn)
_ = conn.SetReadDeadline(time.Now().Add(time.Second))

headerResult := conn.ProxyHeader()
if conn.readErr != nil {
t.Fatalf("unexpected read header error: %v", conn.readErr)
}
if headerResult == nil {
t.Fatalf("expected header, got nil")
}
if conn.bufReader == nil || conn.bufReader.Size() < 16+int(length) {
t.Fatalf("expected buffer size >= %d, got %d", 16+int(length), conn.bufReader.Size())
}
}

func benchmarkTCPProxy(size int, b *testing.B) {
// create and start the echo backend
backend, err := net.Listen("tcp", testLocalhostRandomPort)
Expand Down
73 changes: 73 additions & 0 deletions v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,79 @@ func TestParseV2Invalid(t *testing.T) {
}
}

func TestV2PreflightHeaderSize(t *testing.T) {
build := func(command ProtocolVersionAndCommand, transport AddressFamilyAndProtocol, length uint16) []byte {
header := make([]byte, 16)
copy(header[:12], SIGV2)
header[12] = byte(command)
header[13] = byte(transport)
binary.BigEndian.PutUint16(header[14:16], length)
return header
}

tests := []struct {
desc string
peeked []byte
wantSize int
wantAccept bool
}{
{
desc: "valid unspec local length zero",
peeked: build(LOCAL, UNSPEC, 0),
wantSize: 16,
wantAccept: true,
},
{
desc: "valid proxy tcpv4",
peeked: build(PROXY, TCPv4, lengthV4),
wantSize: 16 + int(lengthV4),
wantAccept: true,
},
{
desc: "short preamble",
peeked: SIGV2,
wantSize: 0,
wantAccept: false,
},
{
desc: "non v2 signature",
peeked: build(PROXY, TCPv4, lengthV4)[1:],
wantSize: 0,
wantAccept: false,
},
{
desc: "unsupported command",
peeked: build(ProtocolVersionAndCommand(invalidRune), TCPv4, lengthV4),
wantSize: 0,
wantAccept: false,
},
{
desc: "unspec with proxy command",
peeked: build(PROXY, UNSPEC, 0),
wantSize: 0,
wantAccept: false,
},
{
desc: "invalid length for tcpv4",
peeked: build(PROXY, TCPv4, 0),
wantSize: 0,
wantAccept: false,
},
}

for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
gotSize, gotAccept := v2PreflightHeaderSize(tt.peeked)
if gotAccept != tt.wantAccept {
t.Fatalf("accept=%v, want=%v", gotAccept, tt.wantAccept)
}
if gotSize != tt.wantSize {
t.Fatalf("size=%d, want=%d", gotSize, tt.wantSize)
}
})
}
}

var validParseAndWriteV2Tests = []struct {
desc string
reader *bufio.Reader
Expand Down
Loading