Skip to content

Commit e41e72c

Browse files
authored
feat(middleware): Enhance rate limiter to support reverse proxies (#106)
Signed-off-by: ramsyana <[email protected]>
1 parent 2206928 commit e41e72c

File tree

3 files changed

+86
-19
lines changed

3 files changed

+86
-19
lines changed

cmd/app/server/main.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ func NewApp() (*App, error) {
7171
sugar.Fatalf("Failed to load TLS credentials: %v", err)
7272
return nil, err
7373
}
74-
rateLimiter := middlewares.NewRateLimiter(5, 10)
74+
75+
// Initialize rate limiter with default trusted proxies
76+
trustedProxies := middlewares.DefaultTrustedProxies()
77+
sugar.Infof("Initializing rate limiter with trusted proxies: %v", trustedProxies)
78+
rateLimiter := middlewares.NewRateLimiter(5, 10, trustedProxies)
7579

7680
// Create the gRPC server with TLS and middleware.
7781
grpcServer := grpc.NewServer(

pkg/middlewares/middlewares_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestCORSMiddleware(t *testing.T) {
4848

4949
// Test Rate Limiting Middleware
5050
func TestRateLimiter(t *testing.T) {
51-
limiter := NewRateLimiter(1, 1) // 1 request per second
51+
limiter := NewRateLimiter(1, 1, DefaultTrustedProxies()) // 1 request per second
5252

5353
clientID := "test-client"
5454

pkg/middlewares/rate_limiter.go

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,72 @@ package middlewares
22

33
import (
44
"context"
5+
"net"
6+
"strings"
57
"sync"
68
"time"
79

810
"golang.org/x/time/rate"
911
"google.golang.org/grpc"
1012
"google.golang.org/grpc/codes"
13+
"google.golang.org/grpc/metadata"
1114
"google.golang.org/grpc/peer"
1215
"google.golang.org/grpc/status"
1316
)
1417

1518
// RateLimiter structure
1619
type RateLimiter struct {
17-
mu sync.Mutex
18-
limiters map[string]*rate.Limiter
19-
rate rate.Limit
20-
burst int
20+
mu sync.Mutex
21+
limiters map[string]*rate.Limiter
22+
rate rate.Limit
23+
burst int
24+
trustedProxies map[string]bool
25+
maxLimiters int
2126
}
2227

23-
// NewRateLimiter initializes a rate limiter
24-
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
28+
// NewRateLimiter initializes a rate limiter with configurable rate, burst, and trusted proxies
29+
func NewRateLimiter(r rate.Limit, b int, proxies []string) *RateLimiter {
30+
trusted := make(map[string]bool)
31+
for _, proxy := range proxies {
32+
if isValidIP(proxy) {
33+
trusted[proxy] = true
34+
}
35+
}
36+
2537
return &RateLimiter{
26-
limiters: make(map[string]*rate.Limiter),
27-
rate: r,
28-
burst: b,
38+
limiters: make(map[string]*rate.Limiter),
39+
rate: r,
40+
burst: b,
41+
trustedProxies: trusted,
42+
maxLimiters: 10000, // Default maximum number of limiters to prevent memory leaks
2943
}
3044
}
3145

32-
// getLimiter gets or creates a rate limiter for a specific client
46+
// isValidIP checks if a string is a valid IP address
47+
func isValidIP(ip string) bool {
48+
return net.ParseIP(ip) != nil
49+
}
50+
51+
// GetLimiter gets or creates a rate limiter for a specific client
3352
func (r *RateLimiter) GetLimiter(clientID string) *rate.Limiter {
3453
r.mu.Lock()
3554
defer r.mu.Unlock()
3655

56+
// Return existing limiter if it exists
3757
if limiter, exists := r.limiters[clientID]; exists {
3858
return limiter
3959
}
4060

61+
// Prevent memory leaks by enforcing a maximum number of limiters
62+
if len(r.limiters) >= r.maxLimiters {
63+
// Simple eviction strategy: remove one random entry
64+
// For production, consider using LRU or similar algorithm
65+
for k := range r.limiters {
66+
delete(r.limiters, k)
67+
break
68+
}
69+
}
70+
4171
limiter := rate.NewLimiter(r.rate, r.burst)
4272
r.limiters[clientID] = limiter
4373

@@ -52,23 +82,56 @@ func (r *RateLimiter) GetLimiter(clientID string) *rate.Limiter {
5282
return limiter
5383
}
5484

85+
// DefaultTrustedProxies returns a list of commonly trusted proxy IPs
86+
func DefaultTrustedProxies() []string {
87+
return []string{"127.0.0.1", "::1"}
88+
}
89+
5590
// RateLimiterInterceptor applies rate limiting
5691
func (r *RateLimiter) RateLimiterInterceptor(
5792
ctx context.Context,
5893
req interface{},
5994
info *grpc.UnaryServerInfo,
6095
handler grpc.UnaryHandler) (interface{}, error) {
6196

62-
// Extract client identifier (can use IP or API key)
63-
// Extract the client's IP address from the context using the peer package.
6497
p, ok := peer.FromContext(ctx)
98+
if !ok {
99+
return nil, status.Errorf(codes.Internal, "could not determine peer")
100+
}
101+
102+
peerIP, _, err := net.SplitHostPort(p.Addr.String())
103+
if err != nil {
104+
return nil, status.Errorf(codes.Internal, "invalid peer address: %v", err)
105+
}
106+
65107
var clientID string
66-
if ok {
67-
clientID = p.Addr.String()
68-
// Optionally, further parse clientID if needed (e.g., remove port information)
69-
} else {
70-
clientID = "unknown"
108+
109+
// Only trust proxy headers if the request is from a trusted proxy
110+
if r.trustedProxies[peerIP] {
111+
if md, ok := metadata.FromIncomingContext(ctx); ok {
112+
// Check X-Forwarded-For (may be a comma-separated list)
113+
if xff := md.Get("x-forwarded-for"); len(xff) > 0 && xff[0] != "" {
114+
ips := strings.Split(xff[0], ",")
115+
if len(ips) > 0 && strings.TrimSpace(ips[0]) != "" {
116+
cleanIP := strings.TrimSpace(ips[0])
117+
if isValidIP(cleanIP) {
118+
clientID = cleanIP
119+
}
120+
}
121+
} else if xri := md.Get("x-real-ip"); len(xri) > 0 && xri[0] != "" {
122+
cleanIP := strings.TrimSpace(xri[0])
123+
if isValidIP(cleanIP) {
124+
clientID = cleanIP
125+
}
126+
}
127+
}
128+
}
129+
130+
// If no trusted clientID was found from headers, use the peer's IP
131+
if clientID == "" {
132+
clientID = peerIP
71133
}
134+
72135
limiter := r.GetLimiter(clientID)
73136
if !limiter.Allow() {
74137
return nil, status.Errorf(codes.ResourceExhausted, "Too many requests, slow down")

0 commit comments

Comments
 (0)