diff --git a/policy.go b/policy.go index 71ad62b..6d505be 100644 --- a/policy.go +++ b/policy.go @@ -32,8 +32,31 @@ const ( // a PROXY header is not present, subsequent reads do not. It is the task // of the code using the connection to handle that case properly. REQUIRE + // SKIP accepts a connection without requiring the PROXY header + // Note: an example usage can be found in the SkipProxyHeaderForCIDR + // function. + SKIP ) +// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a +// connection from a skipHeaderCIDR without requiring a PROXY header, e.g. +// Kubernetes pods local traffic. The def is a policy to use when an upstream +// address doesn't match the skipHeaderCIDR. +func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { + return func(upstream net.Addr) (Policy, error) { + ip, err := ipFromAddr(upstream) + if err != nil { + return def, err + } + + if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) { + return SKIP, nil + } + + return def, nil + } +} + // WithPolicy adds given policy to a connection when passed as option to NewConn() func WithPolicy(p Policy) func(*Conn) { return func(c *Conn) { diff --git a/policy_test.go b/policy_test.go index 40a9444..c8b2624 100644 --- a/policy_test.go +++ b/policy_test.go @@ -188,3 +188,26 @@ func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { MustStrictWhiteListPolicy([]string{"20/80"}) } + +func TestSkipProxyHeaderForCIDR(t *testing.T) { + _, cidr, _ := net.ParseCIDR("192.0.2.1/24") + f := SkipProxyHeaderForCIDR(cidr, REJECT) + + upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345") + policy, err := f(upstream) + if err != nil { + t.Fatalf("err: %v", err) + } + if policy != SKIP { + t.Errorf("Expected a SKIP policy for the %s address", upstream) + } + + upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345") + policy, err = f(upstream) + if err != nil { + t.Fatalf("err: %v", err) + } + if policy != REJECT { + t.Errorf("Expected a REJECT policy for the %s address", upstream) + } +} diff --git a/protocol.go b/protocol.go index d299391..0cc7db0 100644 --- a/protocol.go +++ b/protocol.go @@ -74,6 +74,10 @@ func (p *Listener) Accept() (net.Conn, error) { conn.Close() return nil, err } + // Handle a connection as a regular one + if proxyHeaderPolicy == SKIP { + return conn, nil + } } newConn := NewConn( diff --git a/protocol_test.go b/protocol_test.go index 888c735..5311aae 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -923,6 +923,59 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) { } } +func TestSkipProxyProtocolPolicy(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + + policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil } + + pl := &Listener{ + Listener: l, + Policy: policyFunc, + } + + cliResult := make(chan error) + ping := []byte("ping") + go func() { + conn, err := net.Dial("tcp", pl.Addr().String()) + if err != nil { + cliResult <- err + return + } + defer conn.Close() + conn.Write(ping) + close(cliResult) + }() + + conn, err := pl.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + _, ok := conn.(*net.TCPConn) + if !ok { + t.Fatal("err: should be a tcp connection") + } + _ = conn.LocalAddr() + recv := make([]byte, 4) + _, err = conn.Read(recv) + if err != nil { + t.Fatalf("Unexpected read error: %v", err) + } + + if !bytes.Equal(ping, recv) { + t.Fatalf("Unexpected %s data while expected %s", recv, ping) + } + + err = <-cliResult + if err != nil { + t.Fatalf("client error: %v", err) + } +} + func Test_ConnectionCasts(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil {