diff --git a/policy.go b/policy.go index c59e28b..42712d8 100644 --- a/policy.go +++ b/policy.go @@ -12,6 +12,8 @@ import ( // See below for the different policies. // // In case an error is returned the connection is denied. +// +// Deprecated: use ConnPolicyFunc instead. type PolicyFunc func(upstream net.Addr) (Policy, error) // ConnPolicyFunc can be used to decide whether to trust the PROXY info @@ -53,13 +55,13 @@ const ( SKIP ) -// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a -// connection from a skipHeaderCIDR without requiring a PROXY header, e.g. +// ConnSkipProxyHeaderForCIDR returns a ConnPolicyFunc which can be used to accept +// a connection from a skipHeaderCIDR without requiring a PROXY header, e.g. // Kubernetes pods local traffic. The def is a policy to use when an upstream // address doesn't match the skipHeaderCIDR. -func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { - return func(upstream net.Addr) (Policy, error) { - ip, err := ipFromAddr(upstream) +func ConnSkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) ConnPolicyFunc { + return func(connOpts ConnPolicyOptions) (Policy, error) { + ip, err := ipFromAddr(connOpts.Upstream) if err != nil { return def, err } @@ -72,6 +74,19 @@ func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { } } +// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a +// connection from a skipHeaderCIDR without requiring a PROXY header, e.g. +// Kubernetes pods local traffic. The def is a policy to use when an upstream +// address doesn't match the skipHeaderCIDR. +// +// Deprecated: use ConnSkipProxyHeaderForCIDR instead. +func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { + connPolicy := ConnSkipProxyHeaderForCIDR(skipHeaderCIDR, def) + return func(upstream net.Addr) (Policy, error) { + return connPolicy(ConnPolicyOptions{Upstream: upstream}) + } +} + // WithPolicy adds given policy to a connection when passed as option to NewConn() func WithPolicy(p Policy) func(*Conn) { return func(c *Conn) { @@ -79,29 +94,74 @@ func WithPolicy(p Policy) func(*Conn) { } } +// ConnLaxWhiteListPolicy returns a ConnPolicyFunc which decides whether the +// upstream ip is allowed to send a proxy header based on a list of allowed +// IP addresses and IP ranges. In case upstream IP is not in list the proxy +// header will be ignored. If one of the provided IP addresses or IP ranges +// is invalid it will return an error instead of a ConnPolicyFunc. +func ConnLaxWhiteListPolicy(allowed []string) (ConnPolicyFunc, error) { + allowFrom, err := parse(allowed) + if err != nil { + return nil, err + } + + return connWhitelistPolicy(allowFrom, IGNORE), nil +} + // LaxWhiteListPolicy returns a PolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list the proxy // header will be ignored. If one of the provided IP addresses or IP ranges // is invalid it will return an error instead of a PolicyFunc. +// +// Deprecated: use ConnLaxWhiteListPolicy instead. func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) { - allowFrom, err := parse(allowed) + connPolicy, err := ConnLaxWhiteListPolicy(allowed) if err != nil { return nil, err } - return whitelistPolicy(allowFrom, IGNORE), nil + return func(upstream net.Addr) (Policy, error) { + return connPolicy(ConnPolicyOptions{Upstream: upstream}) + }, nil +} + +// ConnMustLaxWhiteListPolicy returns a ConnLaxWhiteListPolicy but will panic +// if one of the provided IP addresses or IP ranges is invalid. +func ConnMustLaxWhiteListPolicy(allowed []string) ConnPolicyFunc { + pfunc, err := ConnLaxWhiteListPolicy(allowed) + if err != nil { + panic(err) + } + + return pfunc } // MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one // of the provided IP addresses or IP ranges is invalid. +// +// Deprecated: use ConnMustLaxWhiteListPolicy instead. func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { - pfunc, err := LaxWhiteListPolicy(allowed) + connPolicy := ConnMustLaxWhiteListPolicy(allowed) + return func(upstream net.Addr) (Policy, error) { + return connPolicy(ConnPolicyOptions{Upstream: upstream}) + } +} + +// ConnStrictWhiteListPolicy returns a ConnPolicyFunc which decides whether the +// upstream ip is allowed to send a proxy header based on a list of allowed +// IP addresses and IP ranges. In case upstream IP is not in list reading on +// the connection will be refused on the first read. Please note: subsequent +// reads do not error. It is the task of the code using the connection to +// handle that case properly. If one of the provided IP addresses or IP +// ranges is invalid it will return an error instead of a ConnPolicyFunc. +func ConnStrictWhiteListPolicy(allowed []string) (ConnPolicyFunc, error) { + allowFrom, err := parse(allowed) if err != nil { - panic(err) + return nil, err } - return pfunc + return connWhitelistPolicy(allowFrom, REJECT), nil } // StrictWhiteListPolicy returns a PolicyFunc which decides whether the @@ -111,19 +171,23 @@ func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { // reads do not error. It is the task of the code using the connection to // handle that case properly. If one of the provided IP addresses or IP // ranges is invalid it will return an error instead of a PolicyFunc. +// +// Deprecated: use ConnStrictWhiteListPolicy instead. func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { - allowFrom, err := parse(allowed) + connPolicy, err := ConnStrictWhiteListPolicy(allowed) if err != nil { return nil, err } - return whitelistPolicy(allowFrom, REJECT), nil + return func(upstream net.Addr) (Policy, error) { + return connPolicy(ConnPolicyOptions{Upstream: upstream}) + }, nil } -// MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic +// ConnMustStrictWhiteListPolicy returns a ConnStrictWhiteListPolicy but will panic // if one of the provided IP addresses or IP ranges is invalid. -func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { - pfunc, err := StrictWhiteListPolicy(allowed) +func ConnMustStrictWhiteListPolicy(allowed []string) ConnPolicyFunc { + pfunc, err := ConnStrictWhiteListPolicy(allowed) if err != nil { panic(err) } @@ -131,9 +195,20 @@ func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { return pfunc } -func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc { +// MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic +// if one of the provided IP addresses or IP ranges is invalid. +// +// Deprecated: use ConnMustStrictWhiteListPolicy instead. +func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { + connPolicy := ConnMustStrictWhiteListPolicy(allowed) return func(upstream net.Addr) (Policy, error) { - upstreamIP, err := ipFromAddr(upstream) + return connPolicy(ConnPolicyOptions{Upstream: upstream}) + } +} + +func connWhitelistPolicy(allowed []func(net.IP) bool, def Policy) ConnPolicyFunc { + return func(connOpts ConnPolicyOptions) (Policy, error) { + upstreamIP, err := ipFromAddr(connOpts.Upstream) if err != nil { // something is wrong with the source IP, better reject the connection return REJECT, err diff --git a/policy_test.go b/policy_test.go index caa3164..330e89b 100644 --- a/policy_test.go +++ b/policy_test.go @@ -10,6 +10,11 @@ type failingAddr struct{} func (f failingAddr) Network() string { return "failing" } func (f failingAddr) String() string { return "failing" } +type invalidIPAddr struct{} + +func (i invalidIPAddr) Network() string { return "tcp" } +func (i invalidIPAddr) String() string { return "999.999.999.999:1234" } + func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) { var cases = []struct { name string @@ -29,39 +34,76 @@ func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) { } } +func TestWhitelistPolicyReturnsErrorOnInvalidIP(t *testing.T) { + policies := []struct { + name string + policy ConnPolicyFunc + }{ + {"conn strict whitelist policy", ConnMustStrictWhiteListPolicy([]string{"10.0.0.3"})}, + {"conn lax whitelist policy", ConnMustLaxWhiteListPolicy([]string{"10.0.0.3"})}, + } + + for _, tc := range policies { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.policy(ConnPolicyOptions{Upstream: invalidIPAddr{}}) + if err == nil { + t.Fatal("Expected error, got none") + } + }) + } +} + func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { - p := MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) + var cases = []struct { + name string + policy PolicyFunc + }{ + {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, + } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } - policy, err := p(upstream) - if err != nil { - t.Fatalf("err: %v", err) - } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + policy, err := tc.policy(upstream) + if err != nil { + t.Fatalf("err: %v", err) + } - if policy != REJECT { - t.Fatalf("Expected policy REJECT, got %v", policy) + if policy != REJECT { + t.Fatalf("Expected policy REJECT, got %v", policy) + } + }) } } func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { - p := MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) + var cases = []struct { + name string + policy PolicyFunc + }{ + {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, + } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } - policy, err := p(upstream) - if err != nil { - t.Fatalf("err: %v", err) - } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + policy, err := tc.policy(upstream) + if err != nil { + t.Fatalf("err: %v", err) + } - if policy != IGNORE { - t.Fatalf("Expected policy IGNORE, got %v", policy) + if policy != IGNORE { + t.Fatalf("Expected policy IGNORE, got %v", policy) + } + }) } } @@ -122,30 +164,50 @@ func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing. } func Test_CreateWhitelistPolicyWithInvalidCidrReturnsError(t *testing.T) { - _, err := StrictWhiteListPolicy([]string{"20/80"}) - if err == nil { - t.Error("Expected error, got none") + var cases = []struct { + name string + fn func() error + }{ + {"strict whitelist policy", func() error { + _, err := StrictWhiteListPolicy([]string{"20/80"}) + return err + }}, + {"lax whitelist policy", func() error { + _, err := LaxWhiteListPolicy([]string{"20/80"}) + return err + }}, } -} -func Test_CreateWhitelistPolicyWithInvalidIpAddressReturnsError(t *testing.T) { - _, err := StrictWhiteListPolicy([]string{"855.222.233.11"}) - if err == nil { - t.Error("Expected error, got none") + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if err := tc.fn(); err == nil { + t.Error("Expected error, got none") + } + }) } } -func Test_CreateLaxPolicyWithInvalidCidrReturnsError(t *testing.T) { - _, err := LaxWhiteListPolicy([]string{"20/80"}) - if err == nil { - t.Error("Expected error, got none") +func Test_CreateWhitelistPolicyWithInvalidIpAddressReturnsError(t *testing.T) { + var cases = []struct { + name string + fn func() error + }{ + {"strict whitelist policy", func() error { + _, err := StrictWhiteListPolicy([]string{"855.222.233.11"}) + return err + }}, + {"lax whitelist policy", func() error { + _, err := LaxWhiteListPolicy([]string{"855.222.233.11"}) + return err + }}, } -} -func Test_CreateLaxPolicyWithInvalidIpAddresseturnsError(t *testing.T) { - _, err := LaxWhiteListPolicy([]string{"855.222.233.11"}) - if err == nil { - t.Error("Expected error, got none") + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if err := tc.fn(); err == nil { + t.Error("Expected error, got none") + } + }) } } @@ -189,6 +251,39 @@ func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { MustStrictWhiteListPolicy([]string{"20/80"}) } +func TestWhiteListPolicyFuncsReturnPolicies(t *testing.T) { + strictPolicy, err := StrictWhiteListPolicy([]string{"10.0.0.3"}) + if err != nil { + t.Fatalf("err: %v", err) + } + + laxPolicy, err := LaxWhiteListPolicy([]string{"10.0.0.3"}) + if err != nil { + t.Fatalf("err: %v", err) + } + + upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") + if err != nil { + t.Fatalf("err: %v", err) + } + + policy, err := strictPolicy(upstream) + if err != nil { + t.Fatalf("err: %v", err) + } + if policy != USE { + t.Fatalf("Expected policy USE, got %v", policy) + } + + policy, err = laxPolicy(upstream) + if err != nil { + t.Fatalf("err: %v", err) + } + if policy != USE { + t.Fatalf("Expected policy USE, got %v", policy) + } +} + func TestSkipProxyHeaderForCIDR(t *testing.T) { _, cidr, _ := net.ParseCIDR("192.0.2.1/24") f := SkipProxyHeaderForCIDR(cidr, REJECT) @@ -212,6 +307,83 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) { } } +func TestConnSkipProxyHeaderForCIDRReturnsErrorOnInvalidAddress(t *testing.T) { + _, cidr, _ := net.ParseCIDR("192.0.2.1/24") + policy := ConnSkipProxyHeaderForCIDR(cidr, IGNORE) + + result, err := policy(ConnPolicyOptions{Upstream: failingAddr{}}) + if err == nil { + t.Fatal("Expected error, got none") + } + if result != IGNORE { + t.Fatalf("Expected policy IGNORE, got %v", result) + } +} + +func TestConnSkipProxyHeaderForCIDR(t *testing.T) { + _, cidr, _ := net.ParseCIDR("192.0.2.1/24") + policy := ConnSkipProxyHeaderForCIDR(cidr, REJECT) + + upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345") + result, err := policy(ConnPolicyOptions{Upstream: upstream}) + if err != nil { + t.Fatalf("err: %v", err) + } + if result != SKIP { + t.Errorf("Expected a SKIP policy for the %s address", upstream) + } + + upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345") + result, err = policy(ConnPolicyOptions{Upstream: upstream}) + if err != nil { + t.Fatalf("err: %v", err) + } + if result != REJECT { + t.Errorf("Expected a REJECT policy for the %s address", upstream) + } +} + +func TestConnWhitelistPolicies(t *testing.T) { + var cases = []struct { + name string + policy ConnPolicyFunc + expectedUse Policy + expectedReject Policy + }{ + {"conn strict whitelist policy", ConnMustStrictWhiteListPolicy([]string{"10.0.0.3"}), USE, REJECT}, + {"conn lax whitelist policy", ConnMustLaxWhiteListPolicy([]string{"10.0.0.3"}), USE, IGNORE}, + } + + allowed, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") + if err != nil { + t.Fatalf("err: %v", err) + } + denied, err := net.ResolveTCPAddr("tcp", "10.0.0.4:45738") + if err != nil { + t.Fatalf("err: %v", err) + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + policy, err := tc.policy(ConnPolicyOptions{Upstream: allowed}) + if err != nil { + t.Fatalf("err: %v", err) + } + if policy != tc.expectedUse { + t.Fatalf("Expected policy %v, got %v", tc.expectedUse, policy) + } + + policy, err = tc.policy(ConnPolicyOptions{Upstream: denied}) + if err != nil { + t.Fatalf("err: %v", err) + } + if policy != tc.expectedReject { + t.Fatalf("Expected policy %v, got %v", tc.expectedReject, policy) + } + }) + } +} + func TestTrustProxyHeaderFrom(t *testing.T) { upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { diff --git a/protocol_test.go b/protocol_test.go index 2928373..0b17746 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -704,11 +704,11 @@ func TestPanicIfPolicyAndConnPolicySet(t *testing.T) { }() conn, err := pl.Accept() if err != nil { - t.Fatalf("Expected the accept to panic but did not and error is returned, got %v", err) + t.Fatalf("expected the accept to panic but did not and error is returned, got %v", err) } if conn != nil { - t.Fatalf("xpected the accept to panic but did not, got %v", conn) + t.Fatalf("expected the accept to panic but did not, got %v", conn) } t.Fatalf("expected the accept to panic but did not") }