Skip to content

Commit 6a007c7

Browse files
committed
Add SKIP policy to not expect a PROXY header
1 parent 195fedc commit 6a007c7

4 files changed

Lines changed: 115 additions & 0 deletions

File tree

policy.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,31 @@ const (
3232
// a PROXY header is not present, subsequent reads do not. It is the task
3333
// of the code using the connection to handle that case properly.
3434
REQUIRE
35+
// SKIP accepts a connection without requiring the PROXY header
36+
// Note: an example usage can be found in the SkipProxyHeaderForCIDR
37+
// function.
38+
SKIP
3539
)
3640

41+
// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a
42+
// connection from a skipHeaderCIDR without requiring a PROXY header, e.g.
43+
// Kubernetes pods local traffic. The def is a policy to use when an upstream
44+
// address doesn't match the skipHeaderCIDR.
45+
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
46+
return func(upstream net.Addr) (Policy, error) {
47+
ip, err := ipFromAddr(upstream)
48+
if err != nil {
49+
return def, err
50+
}
51+
52+
if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) {
53+
return SKIP, nil
54+
}
55+
56+
return def, nil
57+
}
58+
}
59+
3760
// WithPolicy adds given policy to a connection when passed as option to NewConn()
3861
func WithPolicy(p Policy) func(*Conn) {
3962
return func(c *Conn) {

policy_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,26 @@ func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) {
188188

189189
MustStrictWhiteListPolicy([]string{"20/80"})
190190
}
191+
192+
func TestSkipProxyHeaderForCIDR(t *testing.T) {
193+
_, cidr, _ := net.ParseCIDR("192.0.2.1/24")
194+
f := SkipProxyHeaderForCIDR(cidr, REJECT)
195+
196+
upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345")
197+
policy, err := f(upstream)
198+
if err != nil {
199+
t.Fatalf("err: %v", err)
200+
}
201+
if policy != SKIP {
202+
t.Errorf("Expected a SKIP policy for the %s address", upstream)
203+
}
204+
205+
upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345")
206+
policy, err = f(upstream)
207+
if err != nil {
208+
t.Fatalf("err: %v", err)
209+
}
210+
if policy != REJECT {
211+
t.Errorf("Expected a REJECT policy for the %s address", upstream)
212+
}
213+
}

protocol.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ func (p *Listener) Accept() (net.Conn, error) {
7474
conn.Close()
7575
return nil, err
7676
}
77+
// Handle a connection as a regular one
78+
if proxyHeaderPolicy == SKIP {
79+
return conn, nil
80+
}
7781
}
7882

7983
newConn := NewConn(

protocol_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,71 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) {
814814
}
815815
}
816816

817+
func TestSkipProxyProtocolPolicy(t *testing.T) {
818+
l, err := net.Listen("tcp", "127.0.0.1:0")
819+
if err != nil {
820+
t.Fatalf("err: %v", err)
821+
}
822+
823+
policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil }
824+
825+
timeout := time.Minute
826+
pl := &Listener{
827+
Listener: l,
828+
Policy: policyFunc,
829+
ReadHeaderTimeout: timeout,
830+
}
831+
832+
ticker := time.NewTicker(timeout)
833+
done := make(chan bool)
834+
defer func() {
835+
close(done)
836+
ticker.Stop()
837+
}()
838+
839+
go func() {
840+
for {
841+
select {
842+
case <-done:
843+
return
844+
case <-ticker.C:
845+
t.Fatalf("Timeout waiting for traffic")
846+
}
847+
}
848+
}()
849+
850+
ping := []byte("ping")
851+
go func() {
852+
conn, err := net.Dial("tcp", pl.Addr().String())
853+
if err != nil {
854+
t.Fatalf("err: %v", err)
855+
}
856+
defer conn.Close()
857+
conn.Write(ping)
858+
}()
859+
860+
conn, err := pl.Accept()
861+
if err != nil {
862+
t.Fatalf("err: %v", err)
863+
}
864+
defer conn.Close()
865+
866+
_, ok := conn.(*net.TCPConn)
867+
if !ok {
868+
t.Fatal("err: should be a tcp connection")
869+
}
870+
_ = conn.LocalAddr()
871+
recv := make([]byte, 4)
872+
_, err = conn.Read(recv)
873+
if err != nil {
874+
t.Fatalf("Unexpected read error: %v", err)
875+
}
876+
877+
if !bytes.Equal(ping, recv) {
878+
t.Fatalf("Unexpected %s data while expected %s", recv, ping)
879+
}
880+
}
881+
817882
func Test_ConnectionCasts(t *testing.T) {
818883
l, err := net.Listen("tcp", "127.0.0.1:0")
819884
if err != nil {

0 commit comments

Comments
 (0)