Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,31 @@ const (
// a PROXY header is not present, subsequent reads do not. It is the task
// of the code using the connection to handle that case properly.
REQUIRE
// SKIP accepts a connection without requiring the PROXY header
// Note: an example usage can be found in the SkipProxyHeaderForCIDR
// function.
SKIP
)

// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a
// connection from a skipHeaderCIDR without requiring a PROXY header, e.g.
// Kubernetes pods local traffic. The def is a policy to use when an upstream
// address doesn't match the skipHeaderCIDR.
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
return func(upstream net.Addr) (Policy, error) {
ip, err := ipFromAddr(upstream)
if err != nil {
return def, err
}

if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) {
return SKIP, nil
}

return def, nil
}
}

// WithPolicy adds given policy to a connection when passed as option to NewConn()
func WithPolicy(p Policy) func(*Conn) {
return func(c *Conn) {
Expand Down
23 changes: 23 additions & 0 deletions policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,26 @@ func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) {

MustStrictWhiteListPolicy([]string{"20/80"})
}

func TestSkipProxyHeaderForCIDR(t *testing.T) {
_, cidr, _ := net.ParseCIDR("192.0.2.1/24")
f := SkipProxyHeaderForCIDR(cidr, REJECT)

upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345")
policy, err := f(upstream)
if err != nil {
t.Fatalf("err: %v", err)
}
if policy != SKIP {
t.Errorf("Expected a SKIP policy for the %s address", upstream)
}

upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345")
policy, err = f(upstream)
if err != nil {
t.Fatalf("err: %v", err)
}
if policy != REJECT {
t.Errorf("Expected a REJECT policy for the %s address", upstream)
}
}
4 changes: 4 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ func (p *Listener) Accept() (net.Conn, error) {
conn.Close()
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil
}
}

newConn := NewConn(
Expand Down
53 changes: 53 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,59 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) {
}
}

func TestSkipProxyProtocolPolicy(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil }

pl := &Listener{
Listener: l,
Policy: policyFunc,
}

cliResult := make(chan error)
ping := []byte("ping")
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()
conn.Write(ping)
close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

_, ok := conn.(*net.TCPConn)
if !ok {
t.Fatal("err: should be a tcp connection")
}
_ = conn.LocalAddr()
recv := make([]byte, 4)
_, err = conn.Read(recv)
if err != nil {
t.Fatalf("Unexpected read error: %v", err)
}

if !bytes.Equal(ping, recv) {
t.Fatalf("Unexpected %s data while expected %s", recv, ping)
}

err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}
}

func Test_ConnectionCasts(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down