diff --git a/protocol_test.go b/protocol_test.go index 2e4fd03..769d573 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -1480,6 +1480,72 @@ func TestSkipProxyProtocolConnPolicy(t *testing.T) { } } +func TestLocalCommandUsesUnderlyingAddrs(t *testing.T) { + l, err := net.Listen("tcp", testLocalhostRandomPort) + if err != nil { + t.Fatalf("err: %v", err) + } + + pl := &Listener{Listener: l} + + header := &Header{ + Version: 2, + Command: LOCAL, + TransportProtocol: UNSPEC, + } + + cliResult := make(chan error) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + + // Write a LOCAL header with no address information. + if _, err := header.WriteTo(conn); err != nil { + cliResult <- err + return + } + if _, err := conn.Write([]byte("ping")); err != nil { + cliResult <- err + return + } + + // Close client side to avoid leaving the connection open. + if err := conn.Close(); err != nil { + cliResult <- err + return + } + + close(cliResult) + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + t.Cleanup(func() { + if closeErr := conn.Close(); closeErr != nil { + t.Errorf("failed to close connection: %v", closeErr) + } + }) + + proxyConn := conn.(*Conn) + // LOCAL should make LocalAddr/RemoteAddr fall back to underlying addresses. + if proxyConn.LocalAddr().String() != proxyConn.Raw().LocalAddr().String() { + t.Fatalf("LocalAddr should use underlying address for LOCAL command") + } + if proxyConn.RemoteAddr().String() != proxyConn.Raw().RemoteAddr().String() { + t.Fatalf("RemoteAddr should use underlying address for LOCAL command") + } + + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) + } +} + func Test_ConnectionCasts(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil {