diff --git a/internal/api/middleware.go b/internal/api/middleware.go index e41ae80c3..ab28a8c58 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -62,18 +62,51 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error { var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered") func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error { - if limitHeader := a.config.RateLimitHeader; limitHeader != "" { - key := req.Header.Get(limitHeader) - - if key == "" { - log := observability.GetLogEntry(req).Entry - log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") - } else { - err := tollbooth.LimitByKeys(lmt, []string{key}) - if err != nil { - return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") - } - } + limitHeader := a.config.RateLimitHeader + + // If no rate limit header was set, ignore rate limiting + if limitHeader == "" { + return nil + } + + valuesStr := req.Header.Get(limitHeader) + + // If a rate limit header was set, but has no value, ignore rate limiting but warn with an error + if valuesStr == "" { + log := observability.GetLogEntry(req).Entry + log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") + + return nil + } + + // According to RFC 7230 section 3.2.2, multiple headers with the same name are equivalent + // to a single header with that name where each value is separated by a comma and whitespace. + // + // Note that there is some ambiguity in RFC 7230 where section 3.2.4 states that + // header field values (which can contain commas) are processed independently of the header + // field name, and thus it is not always clear if a comma is a list delimiter or simply par + // of a single value. + // + // Given that this function is primarily for use with headers like X-Forwarded-For which + // vendors generally combine into comma-separated lists, we opt for the simpler approach + // here and split the header value by commas before taking the first value. + values := strings.SplitN(valuesStr, ",", 2) + + // We will always get at least one value back, so this operation is safe + key := strings.TrimSpace(values[0]) + + // If the rate limit header has at least one value, but the first value is all whitespace, return a warning. + // This will happen if the header is something like "X-Foo-Bar: ,baz". + if key == "" { + log := observability.GetLogEntry(req).Entry + log.WithField("header", limitHeader).Warn("first rate limit header value is empty, rate limiting is not applied") + + return nil + } + + // Otherwise, apply rate limiting based on the first rate limit header value + if err := tollbooth.LimitByKeys(lmt, []string{key}); err != nil { + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached") } return nil diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 68dbabb7c..34db4c0f0 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -415,6 +415,108 @@ func TestTimeoutResponseWriter(t *testing.T) { require.Equal(t, w1.Result(), w2.Result()) } +func (ts *MiddlewareTestSuite) TestPerformRateLimiting() { + ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting" + + tests := []struct { + name string + headerValues []string + expError error + }{ + { + name: "no value", + headerValues: []string{ + "", + "", + }, + expError: nil, + }, + { + name: "single end user value", + headerValues: []string{ + "192.168.1.100", + "192.168.1.100", + }, + expError: apierrors.NewTooManyRequestsError( + apierrors.ErrorCodeOverRequestRateLimit, + "Request rate limit reached", + ), + }, + { + name: "same end user value, multiple proxies", + headerValues: []string{ + "2600:cafe:beef::1,192.168.1.100", + "2600:cafe:beef::1,192.168.1.200", + }, + expError: apierrors.NewTooManyRequestsError( + apierrors.ErrorCodeOverRequestRateLimit, + "Request rate limit reached", + ), + }, + { + name: "multiple end user values, single proxy", + headerValues: []string{ + "2600:cafe:beef::1,192.168.1.100", + "3700:dead:abcd::2,192.168.1.100", + }, + expError: nil, + }, + { + name: "same end user value, multiple proxies, with whitespace", + headerValues: []string{ + "2600:cafe:beef::1 ,192.168.1.100", + "2600:cafe:beef::1 , 192.168.1.200", + }, + expError: apierrors.NewTooManyRequestsError( + apierrors.ErrorCodeOverRequestRateLimit, + "Request rate limit reached", + ), + }, + { + name: "empty header, all whitespace", + headerValues: []string{ + " ", + }, + expError: nil, + }, + { + name: "empty first key, no whitespace", + headerValues: []string{ + ",192.168.1.100", + }, + expError: nil, + }, + { + name: "empty first key, with whitespace", + headerValues: []string{ + " ,192.168.1.100", + }, + expError: nil, + }, + } + + for _, tt := range tests { + // Trigger a rate limiting error if we see the same end-user key twice in the same + // test case + lmt := tollbooth.NewLimiter( + 1, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }, + ) + + var obsError error + + for _, h := range tt.headerValues { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, h) + obsError = ts.API.performRateLimiting(lmt, req) + } + + require.ErrorIs(ts.T(), obsError, tt.expError, "error for test '%s'", tt.name) + } +} + func (ts *MiddlewareTestSuite) TestLimitHandler() { ts.Config.RateLimitHeader = "X-Rate-Limit" lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{