diff --git a/protocol.go b/protocol.go index b331866..9f52db0 100644 --- a/protocol.go +++ b/protocol.go @@ -155,8 +155,9 @@ func (p *Listener) Addr() net.Addr { func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { // For v1 the header length is at most 108 bytes. // For v2 the header length is at most 52 bytes plus the length of the TLVs. - // We use 256 bytes to be safe. - const bufSize = 256 + // PP2_SUBTYPE_SSL_CLIENT_CERT might be a few kilobytes. We use 4096 bytes + // to be safe. + const bufSize = 4096 br := bufio.NewReaderSize(conn, bufSize) pConn := &Conn{ diff --git a/v2_test.go b/v2_test.go index 11c29a3..16af742 100644 --- a/v2_test.go +++ b/v2_test.go @@ -15,7 +15,8 @@ var ( invalidRune = byte('\x99') // Lengths to use in tests. - lengthPadded = uint16(84) + lengthPadded = uint16(84) + lengthTLVsTooLarge = uint16(10 * 1024) lengthEmptyBytes = func() []byte { a := make([]byte, 2) @@ -27,6 +28,11 @@ var ( binary.BigEndian.PutUint16(a, lengthPadded) return a }() + lengthTLVsTooLargeBytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthTLVsTooLarge) + return a + }() // If life gives you lemons, make mojitos. portBytes = func() []byte { @@ -60,9 +66,10 @@ var ( _, _ = iorand.Read(tlv) return tlv }() - fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV) - fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV) - fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV) + fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV) + fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV) + fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV) + fixtureTLVsTooLarge = append(append(lengthTLVsTooLargeBytes, fixtureIPv4Address...), make([]byte, lengthTLVsTooLarge-lengthV4)...) // Arbitrary bytes following proxy bytes. arbitraryTailBytes = []byte{'\x99', '\x97', '\x98'} @@ -153,6 +160,11 @@ var invalidParseV2Tests = []struct { reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)), expectedError: ErrInvalidLength, }, + { + desc: "TCPv4 with length too large", + reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureTLVsTooLarge...)), + expectedError: ErrInvalidLength, + }, } func TestParseV2Invalid(t *testing.T) {