diff --git a/protocol.go b/protocol.go index b331866..af3be4e 100644 --- a/protocol.go +++ b/protocol.go @@ -150,8 +150,13 @@ func (p *Listener) Addr() net.Addr { return p.Listener.Addr() } -// NewConn is used to wrap a net.Conn that may be speaking -// the proxy protocol into a proxyproto.Conn. +// NewConn is used to wrap a net.Conn that may be speaking the PROXY protocol +// into a proxyproto.Conn. +// +// NOTE: NewConn may interfere with previously set ReadDeadline on the provided net.Conn, +// because it sets a temporary deadline when detecting and reading the PROXY protocol header. +// If you need to enforce a specific ReadDeadline on the connection, be sure to call Conn.SetReadDeadline +// again after NewConn returns, to restore your desired deadline. func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { // For v1 the header length is at most 108 bytes. // For v2 the header length is at most 52 bytes plus the length of the TLVs. @@ -176,11 +181,9 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { // the initial scan. If there is an error parsing the header, // it is returned and the socket is closed. func (p *Conn) Read(b []byte) (int, error) { - p.once.Do(func() { - p.readErr = p.readHeader() - }) - if p.readErr != nil { - return 0, p.readErr + // Ensure header processing runs at most once and surface any errors. + if err := p.ensureHeaderProcessed(); err != nil { + return 0, err } return p.reader.Read(b) @@ -188,6 +191,10 @@ func (p *Conn) Read(b []byte) (int, error) { // Write wraps original conn.Write. func (p *Conn) Write(b []byte) (int, error) { + // Ensure header processing has completed before writing. + if err := p.ensureHeaderProcessed(); err != nil { + return 0, err + } return p.conn.Write(b) } @@ -199,7 +206,8 @@ func (p *Conn) Close() error { // ProxyHeader returns the proxy protocol header, if any. If an error occurs // while reading the proxy header, nil is returned. func (p *Conn) ProxyHeader() *Header { - p.once.Do(func() { p.readErr = p.readHeader() }) + // Ensure header processing runs at most once. + _ = p.ensureHeaderProcessed() return p.header } @@ -210,7 +218,8 @@ func (p *Conn) ProxyHeader() *Header { // from the proxy header even if the proxy header itself is // syntactically correct. func (p *Conn) LocalAddr() net.Addr { - p.once.Do(func() { p.readErr = p.readHeader() }) + // Ensure header processing runs at most once. + _ = p.ensureHeaderProcessed() if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { return p.conn.LocalAddr() } @@ -225,7 +234,8 @@ func (p *Conn) LocalAddr() net.Addr { // from the proxy header even if the proxy header itself is // syntactically correct. func (p *Conn) RemoteAddr() net.Addr { - p.once.Do(func() { p.readErr = p.readHeader() }) + // Ensure header processing runs at most once. + _ = p.ensureHeaderProcessed() if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { return p.conn.RemoteAddr() } @@ -291,11 +301,25 @@ func (p *Conn) SetWriteDeadline(t time.Time) error { // readHeader reads the proxy protocol header from the connection. func (p *Conn) readHeader() error { // If the connection's readHeaderTimeout is more than 0, - // push our deadline back to now plus the timeout. This should only - // run on the connection, as we don't want to override the previous - // read deadline the user may have used. + // apply a temporary deadline without extending a user-configured + // deadline. If the user has no deadline, we use now + timeout. if p.readHeaderTimeout > 0 { - if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil { + var ( + storedDeadline time.Time + hasDeadline bool + ) + if t := p.readDeadline.Load(); t != nil { + storedDeadline = t.(time.Time) + hasDeadline = !storedDeadline.IsZero() + } + + headerDeadline := time.Now().Add(p.readHeaderTimeout) + if hasDeadline && storedDeadline.Before(headerDeadline) { + // Clamp to the user's earlier deadline to avoid extending it. + headerDeadline = storedDeadline + } + + if err := p.conn.SetReadDeadline(headerDeadline); err != nil { return err } } @@ -304,7 +328,7 @@ func (p *Conn) readHeader() error { // If the connection's readHeaderTimeout is more than 0, undo the change to the // deadline that we made above. Because we retain the readDeadline as part of our - // SetReadDeadline override, we know the user's desired deadline so we use that. + // SetReadDeadline override, we can restore the user's deadline (if any). // Therefore, we check whether the error is a net.Timeout and if it is, we decide // the proxy proto does not exist and set the error accordingly. if p.readHeaderTimeout > 0 { @@ -352,8 +376,23 @@ func (p *Conn) readHeader() error { return err } +// ensureHeaderProcessed runs header processing once. +func (p *Conn) ensureHeaderProcessed() error { + p.once.Do(func() { + p.readErr = p.readHeader() + }) + if p.readErr != nil { + return p.readErr + } + return nil +} + // ReadFrom implements the io.ReaderFrom ReadFrom method. func (p *Conn) ReadFrom(r io.Reader) (int64, error) { + // Ensure header processing has completed before reading/writing. + if err := p.ensureHeaderProcessed(); err != nil { + return 0, err + } if rf, ok := p.conn.(io.ReaderFrom); ok { return rf.ReadFrom(r) } @@ -362,9 +401,9 @@ func (p *Conn) ReadFrom(r io.Reader) (int64, error) { // WriteTo implements io.WriterTo. func (p *Conn) WriteTo(w io.Writer) (int64, error) { - p.once.Do(func() { p.readErr = p.readHeader() }) - if p.readErr != nil { - return 0, p.readErr + // Ensure header processing has completed before reading/writing. + if err := p.ensureHeaderProcessed(); err != nil { + return 0, err } b := make([]byte, p.bufReader.Buffered()) diff --git a/protocol_test.go b/protocol_test.go index 2e4fd03..1124e5d 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -223,6 +223,158 @@ func TestUseWithReadHeaderTimeout(t *testing.T) { } } +func TestNewConnSetReadHeaderTimeoutOption(t *testing.T) { + conn, peer := net.Pipe() + t.Cleanup(func() { + if closeErr := conn.Close(); closeErr != nil { + t.Errorf("failed to close connection: %v", closeErr) + } + }) + t.Cleanup(func() { + if closeErr := peer.Close(); closeErr != nil { + t.Errorf("failed to close peer connection: %v", closeErr) + } + }) + + // Ensure SetReadHeaderTimeout sets the connection-specific timeout. + timeout := 150 * time.Millisecond + proxyConn := NewConn(conn, SetReadHeaderTimeout(timeout)) + if proxyConn.readHeaderTimeout != timeout { + t.Fatalf("expected readHeaderTimeout %v, got %v", timeout, proxyConn.readHeaderTimeout) + } +} + +func TestNewConnSetReadHeaderTimeoutIgnoresNegative(t *testing.T) { + conn, peer := net.Pipe() + t.Cleanup(func() { + if closeErr := conn.Close(); closeErr != nil { + t.Errorf("failed to close connection: %v", closeErr) + } + }) + t.Cleanup(func() { + if closeErr := peer.Close(); closeErr != nil { + t.Errorf("failed to close peer connection: %v", closeErr) + } + }) + + // Negative values should be ignored, leaving the timeout unset. + proxyConn := NewConn(conn, SetReadHeaderTimeout(-1)) + if proxyConn.readHeaderTimeout != 0 { + t.Fatalf("expected readHeaderTimeout to remain 0, got %v", proxyConn.readHeaderTimeout) + } +} + +func TestReadHeaderTimeoutRespectsEarlierDeadline(t *testing.T) { + const ( + headerTimeout = 200 * time.Millisecond + userTimeout = 60 * time.Millisecond + tolerance = 100 * time.Millisecond + ) + + l, err := net.Listen("tcp", testLocalhostRandomPort) + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{ + Listener: l, + ReadHeaderTimeout: headerTimeout, + Policy: func(_ net.Addr) (Policy, error) { + // Use REQUIRE so a timeout is surfaced as ErrNoProxyProtocol. + return REQUIRE, nil + }, + } + + type dialResult struct { + conn net.Conn + err error + } + + dialResultCh := make(chan dialResult, 1) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + dialResultCh <- dialResult{conn: conn, err: err} + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + t.Cleanup(func() { + if closeErr := conn.Close(); closeErr != nil { + t.Errorf("failed to close connection: %v", closeErr) + } + }) + + result := <-dialResultCh + if result.err != nil { + t.Fatalf("client error: %v", result.err) + } + t.Cleanup(func() { + if closeErr := result.conn.Close(); closeErr != nil { + t.Errorf("failed to close client connection: %v", closeErr) + } + }) + + // Set a shorter user deadline than the readHeaderTimeout and do not send data. + if err := conn.SetReadDeadline(time.Now().Add(userTimeout)); err != nil { + t.Fatalf("err: %v", err) + } + + start := time.Now() + recv := make([]byte, 1) + _, err = conn.Read(recv) + elapsed := time.Since(start) + + // The read should honor the earlier user deadline instead of waiting + // for the longer readHeaderTimeout. + if !errors.Is(err, ErrNoProxyProtocol) { + t.Fatalf("expected ErrNoProxyProtocol, got: %v", err) + } + if elapsed > userTimeout+tolerance { + t.Fatalf("read exceeded user deadline: elapsed=%v timeout=%v", elapsed, userTimeout) + } +} + +func TestDeadlineSettersAfterHeaderProcessed(t *testing.T) { + conn, peer := net.Pipe() + t.Cleanup(func() { + if closeErr := conn.Close(); closeErr != nil { + t.Errorf("failed to close connection: %v", closeErr) + } + }) + t.Cleanup(func() { + if closeErr := peer.Close(); closeErr != nil { + t.Errorf("failed to close peer connection: %v", closeErr) + } + }) + + proxyConn := NewConn(conn) + + // Ensure header processing completes by sending a non-PROXY byte + // and reading it through the proxy connection. + go func() { + if _, err := peer.Write([]byte("x")); err != nil { + t.Errorf("failed to write peer data: %v", err) + } + }() + buf := make([]byte, 1) + if _, err := proxyConn.Read(buf); err != nil { + t.Fatalf("read failed: %v", err) + } + + deadline := time.Now().Add(time.Second) + if err := proxyConn.SetDeadline(deadline); err != nil { + t.Fatalf("unexpected SetDeadline error: %v", err) + } + if err := proxyConn.SetReadDeadline(deadline); err != nil { + t.Fatalf("unexpected SetReadDeadline error: %v", err) + } + if err := proxyConn.SetWriteDeadline(deadline); err != nil { + t.Fatalf("unexpected SetWriteDeadline error: %v", err) + } +} + func TestReadHeaderTimeoutIsReset(t *testing.T) { const timeout = time.Millisecond * 250 @@ -1927,6 +2079,54 @@ type testConn struct { net.Conn // nil; crash on any unexpected use } +type deadlineConn struct { + deadline time.Time + readDeadline time.Time + writeDeadline time.Time +} + +func (c *deadlineConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *deadlineConn) Write(p []byte) (int, error) { return len(p), nil } +func (c *deadlineConn) Close() error { return nil } +func (c *deadlineConn) LocalAddr() net.Addr { return dummyAddr("local") } +func (c *deadlineConn) RemoteAddr() net.Addr { return dummyAddr("remote") } +func (c *deadlineConn) SetDeadline(t time.Time) error { + c.deadline = t + return nil +} +func (c *deadlineConn) SetReadDeadline(t time.Time) error { + c.readDeadline = t + return nil +} +func (c *deadlineConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +type noReadFromConn struct { + written bytes.Buffer +} + +func (c *noReadFromConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *noReadFromConn) Write(p []byte) (int, error) { + return c.written.Write(p) +} +func (c *noReadFromConn) Close() error { return nil } +func (c *noReadFromConn) LocalAddr() net.Addr { return dummyAddr("local") } +func (c *noReadFromConn) RemoteAddr() net.Addr { return dummyAddr("remote") } +func (c *noReadFromConn) SetDeadline(time.Time) error { return nil } +func (c *noReadFromConn) SetReadDeadline(time.Time) error { + return nil +} +func (c *noReadFromConn) SetWriteDeadline(time.Time) error { + return nil +} + +type dummyAddr string + +func (a dummyAddr) Network() string { return "dummy" } +func (a dummyAddr) String() string { return string(a) } + func (c *testConn) ReadFrom(r io.Reader) (int64, error) { c.readFromCalledWith = r b, err := io.ReadAll(r) @@ -1984,6 +2184,127 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { } } +func TestDeadlineWrappersDelegate(t *testing.T) { + conn := &deadlineConn{} + proxyConn := NewConn(conn) + + deadline := time.Now().Add(2 * time.Second) + readDeadline := time.Now().Add(3 * time.Second) + writeDeadline := time.Now().Add(4 * time.Second) + + // Ensure deadline setters pass through to the underlying connection. + if err := proxyConn.SetDeadline(deadline); err != nil { + t.Fatalf("unexpected SetDeadline error: %v", err) + } + if err := proxyConn.SetReadDeadline(readDeadline); err != nil { + t.Fatalf("unexpected SetReadDeadline error: %v", err) + } + if err := proxyConn.SetWriteDeadline(writeDeadline); err != nil { + t.Fatalf("unexpected SetWriteDeadline error: %v", err) + } + + if !conn.deadline.Equal(deadline) { + t.Fatalf("SetDeadline did not pass through value") + } + if !conn.readDeadline.Equal(readDeadline) { + t.Fatalf("SetReadDeadline did not pass through value") + } + if !conn.writeDeadline.Equal(writeDeadline) { + t.Fatalf("SetWriteDeadline did not pass through value") + } +} + +func TestReadFromFallbackCopiesToConn(t *testing.T) { + conn := &noReadFromConn{} + proxyConn := NewConn(conn) + + payload := []byte("payload") + if _, err := proxyConn.ReadFrom(bytes.NewReader(payload)); err != nil { + t.Fatalf("unexpected ReadFrom error: %v", err) + } + + // When the inner connection does not implement io.ReaderFrom, + // ReadFrom should fall back to io.Copy and write the payload. + if !bytes.Equal(conn.written.Bytes(), payload) { + t.Fatalf("unexpected write content: %q", conn.written.String()) + } +} + +func TestWriteToDrainsBufferedData(t *testing.T) { + l, err := net.Listen("tcp", testLocalhostRandomPort) + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l} + + header := &Header{ + Version: 2, + Command: PROXY, + TransportProtocol: TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(testSourceIPv4Addr), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(testDestinationIPv4Addr), + Port: 2000, + }, + } + + payload := []byte("ping") + + cliResult := make(chan error) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + + // Write the header followed by payload to populate the reader buffer. + if _, err := header.WriteTo(conn); err != nil { + cliResult <- err + return + } + if _, err := conn.Write(payload); err != nil { + cliResult <- err + return + } + + // Close the client so WriteTo's io.Copy completes. + if err := conn.Close(); err != nil { + cliResult <- err + return + } + + close(cliResult) + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + t.Cleanup(func() { + if closeErr := conn.Close(); closeErr != nil { + t.Errorf("failed to close connection: %v", closeErr) + } + }) + + var out bytes.Buffer + if _, err := conn.(*Conn).WriteTo(&out); err != nil { + t.Fatalf("unexpected WriteTo error: %v", err) + } + if !bytes.Equal(out.Bytes(), payload) { + t.Fatalf("unexpected WriteTo output: %q", out.String()) + } + + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) + } +} + func benchmarkTCPProxy(size int, b *testing.B) { // create and start the echo backend backend, err := net.Listen("tcp", testLocalhostRandomPort)