@@ -2,42 +2,72 @@ package middlewares
22
33import (
44 "context"
5+ "net"
6+ "strings"
57 "sync"
68 "time"
79
810 "golang.org/x/time/rate"
911 "google.golang.org/grpc"
1012 "google.golang.org/grpc/codes"
13+ "google.golang.org/grpc/metadata"
1114 "google.golang.org/grpc/peer"
1215 "google.golang.org/grpc/status"
1316)
1417
1518// RateLimiter structure
1619type RateLimiter struct {
17- mu sync.Mutex
18- limiters map [string ]* rate.Limiter
19- rate rate.Limit
20- burst int
20+ mu sync.Mutex
21+ limiters map [string ]* rate.Limiter
22+ rate rate.Limit
23+ burst int
24+ trustedProxies map [string ]bool
25+ maxLimiters int
2126}
2227
23- // NewRateLimiter initializes a rate limiter
24- func NewRateLimiter (r rate.Limit , b int ) * RateLimiter {
28+ // NewRateLimiter initializes a rate limiter with configurable rate, burst, and trusted proxies
29+ func NewRateLimiter (r rate.Limit , b int , proxies []string ) * RateLimiter {
30+ trusted := make (map [string ]bool )
31+ for _ , proxy := range proxies {
32+ if isValidIP (proxy ) {
33+ trusted [proxy ] = true
34+ }
35+ }
36+
2537 return & RateLimiter {
26- limiters : make (map [string ]* rate.Limiter ),
27- rate : r ,
28- burst : b ,
38+ limiters : make (map [string ]* rate.Limiter ),
39+ rate : r ,
40+ burst : b ,
41+ trustedProxies : trusted ,
42+ maxLimiters : 10000 , // Default maximum number of limiters to prevent memory leaks
2943 }
3044}
3145
32- // getLimiter gets or creates a rate limiter for a specific client
46+ // isValidIP checks if a string is a valid IP address
47+ func isValidIP (ip string ) bool {
48+ return net .ParseIP (ip ) != nil
49+ }
50+
51+ // GetLimiter gets or creates a rate limiter for a specific client
3352func (r * RateLimiter ) GetLimiter (clientID string ) * rate.Limiter {
3453 r .mu .Lock ()
3554 defer r .mu .Unlock ()
3655
56+ // Return existing limiter if it exists
3757 if limiter , exists := r .limiters [clientID ]; exists {
3858 return limiter
3959 }
4060
61+ // Prevent memory leaks by enforcing a maximum number of limiters
62+ if len (r .limiters ) >= r .maxLimiters {
63+ // Simple eviction strategy: remove one random entry
64+ // For production, consider using LRU or similar algorithm
65+ for k := range r .limiters {
66+ delete (r .limiters , k )
67+ break
68+ }
69+ }
70+
4171 limiter := rate .NewLimiter (r .rate , r .burst )
4272 r .limiters [clientID ] = limiter
4373
@@ -52,23 +82,56 @@ func (r *RateLimiter) GetLimiter(clientID string) *rate.Limiter {
5282 return limiter
5383}
5484
85+ // DefaultTrustedProxies returns a list of commonly trusted proxy IPs
86+ func DefaultTrustedProxies () []string {
87+ return []string {"127.0.0.1" , "::1" }
88+ }
89+
5590// RateLimiterInterceptor applies rate limiting
5691func (r * RateLimiter ) RateLimiterInterceptor (
5792 ctx context.Context ,
5893 req interface {},
5994 info * grpc.UnaryServerInfo ,
6095 handler grpc.UnaryHandler ) (interface {}, error ) {
6196
62- // Extract client identifier (can use IP or API key)
63- // Extract the client's IP address from the context using the peer package.
6497 p , ok := peer .FromContext (ctx )
98+ if ! ok {
99+ return nil , status .Errorf (codes .Internal , "could not determine peer" )
100+ }
101+
102+ peerIP , _ , err := net .SplitHostPort (p .Addr .String ())
103+ if err != nil {
104+ return nil , status .Errorf (codes .Internal , "invalid peer address: %v" , err )
105+ }
106+
65107 var clientID string
66- if ok {
67- clientID = p .Addr .String ()
68- // Optionally, further parse clientID if needed (e.g., remove port information)
69- } else {
70- clientID = "unknown"
108+
109+ // Only trust proxy headers if the request is from a trusted proxy
110+ if r .trustedProxies [peerIP ] {
111+ if md , ok := metadata .FromIncomingContext (ctx ); ok {
112+ // Check X-Forwarded-For (may be a comma-separated list)
113+ if xff := md .Get ("x-forwarded-for" ); len (xff ) > 0 && xff [0 ] != "" {
114+ ips := strings .Split (xff [0 ], "," )
115+ if len (ips ) > 0 && strings .TrimSpace (ips [0 ]) != "" {
116+ cleanIP := strings .TrimSpace (ips [0 ])
117+ if isValidIP (cleanIP ) {
118+ clientID = cleanIP
119+ }
120+ }
121+ } else if xri := md .Get ("x-real-ip" ); len (xri ) > 0 && xri [0 ] != "" {
122+ cleanIP := strings .TrimSpace (xri [0 ])
123+ if isValidIP (cleanIP ) {
124+ clientID = cleanIP
125+ }
126+ }
127+ }
128+ }
129+
130+ // If no trusted clientID was found from headers, use the peer's IP
131+ if clientID == "" {
132+ clientID = peerIP
71133 }
134+
72135 limiter := r .GetLimiter (clientID )
73136 if ! limiter .Allow () {
74137 return nil , status .Errorf (codes .ResourceExhausted , "Too many requests, slow down" )
0 commit comments