From 98ac070dc9be85c64e5a3fdba94356429dc4439a Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 10 Aug 2024 22:39:43 +0200 Subject: [PATCH] policy: add REFUSE in strict whitelist policies we want to refuse a connection from a not allowed upstream address whether the proxy header is set or not set. Before this change if the upstream address is not allowed: 1) if the policy returns REJECT, the connection is allowed if no proxy header is sent 2) if the policy returns REQUIRE, the connection is allowed if a proxy header is set, even if the upstream address is not allowed to set it. The new REFUSE policy can be returned for not allowed addresses so that the connection is always refused. --- policy.go | 5 +- policy_test.go | 4 +- protocol.go | 4 +- protocol_test.go | 152 +++++++++++++++++++++++++---------------------- 4 files changed, 88 insertions(+), 77 deletions(-) diff --git a/policy.go b/policy.go index ebef8b9..21b076c 100644 --- a/policy.go +++ b/policy.go @@ -51,6 +51,9 @@ const ( // Note: an example usage can be found in the SkipProxyHeaderForCIDR // function. SKIP + // REFUSE is the same as REJECT if a proxy header is set and the same as + // REQUIRE if a proxy header is not set. + REFUSE ) // SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a @@ -117,7 +120,7 @@ func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { return nil, err } - return whitelistPolicy(allowFrom, REJECT), nil + return whitelistPolicy(allowFrom, REFUSE), nil } // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic diff --git a/policy_test.go b/policy_test.go index a888bdd..d4c808e 100644 --- a/policy_test.go +++ b/policy_test.go @@ -42,8 +42,8 @@ func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *t t.Fatalf("err: %v", err) } - if policy != REJECT { - t.Fatalf("Expected policy REJECT, got %v", policy) + if policy != REFUSE { + t.Fatalf("Expected policy REFUSE, got %v", policy) } } diff --git a/protocol.go b/protocol.go index 658900a..178bfdc 100644 --- a/protocol.go +++ b/protocol.go @@ -288,7 +288,7 @@ func (p *Conn) readHeader() error { // let's act as if there was no error when PROXY protocol is not present. if err == ErrNoProxyProtocol { // but not if it is required that the connection has one - if p.ProxyHeaderPolicy == REQUIRE { + if p.ProxyHeaderPolicy == REQUIRE || p.ProxyHeaderPolicy == REFUSE { return err } @@ -298,7 +298,7 @@ func (p *Conn) readHeader() error { // proxy protocol header was found if err == nil && header != nil { switch p.ProxyHeaderPolicy { - case REJECT: + case REJECT, REFUSE: // this connection is not allowed to send one return ErrSuperfluousProxyHeader case USE, REQUIRE: diff --git a/protocol_test.go b/protocol_test.go index fd976d1..0e9d1a9 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -752,100 +752,108 @@ func TestAcceptReturnsErrorWhenConnPolicyFuncErrors(t *testing.T) { } func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) + policyFuncs := []PolicyFunc{ + func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }, + func(upstream net.Addr) (Policy, error) { return REFUSE, nil }, } + for _, policyFunc := range policyFuncs { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } - policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } + pl := &Listener{Listener: l, Policy: policyFunc} - pl := &Listener{Listener: l, Policy: policyFunc} + cliResult := make(chan error) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + defer conn.Close() - cliResult := make(chan error) - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) + if _, err := conn.Write([]byte("ping")); err != nil { + cliResult <- err + return + } + + close(cliResult) + }() + + conn, err := pl.Accept() if err != nil { - cliResult <- err - return + t.Fatalf("err: %v", err) } defer conn.Close() - if _, err := conn.Write([]byte("ping")); err != nil { - cliResult <- err - return + recv := make([]byte, 4) + if _, err = conn.Read(recv); err != ErrNoProxyProtocol { + t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) + } + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) } - - close(cliResult) - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - if _, err = conn.Read(recv); err != ErrNoProxyProtocol { - t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) - } - err = <-cliResult - if err != nil { - t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %v", err) + policyFuncs := []PolicyFunc{ + func(upstream net.Addr) (Policy, error) { return REJECT, nil }, + func(upstream net.Addr) (Policy, error) { return REFUSE, nil }, } + for _, policyFunc := range policyFuncs { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } - policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil } + pl := &Listener{Listener: l, Policy: policyFunc} - pl := &Listener{Listener: l, Policy: policyFunc} + cliResult := make(chan error) + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + defer conn.Close() + header := &Header{ + Version: 2, + Command: PROXY, + TransportProtocol: TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP("20.2.2.2"), + Port: 2000, + }, + } + if _, err := header.WriteTo(conn); err != nil { + cliResult <- err + return + } - cliResult := make(chan error) - go func() { - conn, err := net.Dial("tcp", pl.Addr().String()) + close(cliResult) + }() + + conn, err := pl.Accept() if err != nil { - cliResult <- err - return + t.Fatalf("err: %v", err) } defer conn.Close() - header := &Header{ - Version: 2, - Command: PROXY, - TransportProtocol: TCPv4, - SourceAddr: &net.TCPAddr{ - IP: net.ParseIP("10.1.1.1"), - Port: 1000, - }, - DestinationAddr: &net.TCPAddr{ - IP: net.ParseIP("20.2.2.2"), - Port: 2000, - }, + + recv := make([]byte, 4) + if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader { + t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) } - if _, err := header.WriteTo(conn); err != nil { - cliResult <- err - return + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) } - - close(cliResult) - }() - - conn, err := pl.Accept() - if err != nil { - t.Fatalf("err: %v", err) - } - defer conn.Close() - - recv := make([]byte, 4) - if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader { - t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) - } - err = <-cliResult - if err != nil { - t.Fatalf("client error: %v", err) } } func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) {