Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,47 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error {
var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
if limitHeader := a.config.RateLimitHeader; limitHeader != "" {
key := req.Header.Get(limitHeader)

if key == "" {
log := observability.GetLogEntry(req).Entry
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
} else {
err := tollbooth.LimitByKeys(lmt, []string{key})
if err != nil {
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}
}
limitHeader := a.config.RateLimitHeader

// If no rate limit header was set, ignore rate limiting
if limitHeader == "" {
return nil
}

valuesStr := req.Header.Get(limitHeader)

// If a rate limit header was set, but has no value, ignore rate limiting but warn with an error
if valuesStr == "" {
log := observability.GetLogEntry(req).Entry
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")

return nil
}

// According to RFC 7230 section 3.2.2, multiple headers with the same name are equivalent
// to a single header with that name where each value is separated by a comma and whitespace.
//
// Note that there is some ambiguity in RFC 7230 where section 3.2.4 states that
// header field values (which can contain commas) are processed independently of the header
// field name, and thus it is not always clear if a comma is a list delimiter or simply par
// of a single value.
//
// Given that this function is primarily for use with headers like X-Forwarded-For which
// vendors generally combine into comma-separated lists, we opt for the simpler approach
// here and split the header value by commas before taking the first value.
values := strings.SplitN(valuesStr, ",", 2)

// We will always get at least one value back, so this operation is safe
key := strings.TrimSpace(values[0])

// If the rate limit header has at least one value, but the first value is all whitespace, return an error
if key == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid rate limit header value")
}

// Otherwise, apply rate limiting based on the first rate limit header value
if err := tollbooth.LimitByKeys(lmt, []string{key}); err != nil {
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}

return nil
Expand Down
111 changes: 111 additions & 0 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,117 @@ func TestTimeoutResponseWriter(t *testing.T) {
require.Equal(t, w1.Result(), w2.Result())
}

func (ts *MiddlewareTestSuite) TestPerformRateLimiting() {
ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting"

tests := []struct {
name string
headerValues []string
expError error
}{
{
name: "no value",
headerValues: []string{
"",
"",
},
expError: nil,
},
{
name: "single end user value",
headerValues: []string{
"192.168.1.100",
"192.168.1.100",
},
expError: apierrors.NewTooManyRequestsError(
apierrors.ErrorCodeOverRequestRateLimit,
"Request rate limit reached",
),
},
{
name: "same end user value, multiple proxies",
headerValues: []string{
"2600:cafe:beef::1,192.168.1.100",
"2600:cafe:beef::1,192.168.1.200",
},
expError: apierrors.NewTooManyRequestsError(
apierrors.ErrorCodeOverRequestRateLimit,
"Request rate limit reached",
),
},
{
name: "multiple end user values, single proxy",
headerValues: []string{
"2600:cafe:beef::1,192.168.1.100",
"3700:dead:abcd::2,192.168.1.100",
},
expError: nil,
},
{
name: "same end user value, multiple proxies, with whitespace",
headerValues: []string{
"2600:cafe:beef::1 ,192.168.1.100",
"2600:cafe:beef::1 , 192.168.1.200",
},
expError: apierrors.NewTooManyRequestsError(
apierrors.ErrorCodeOverRequestRateLimit,
"Request rate limit reached",
),
},
{
name: "malformed header, all whitespace",
headerValues: []string{
" ",
},
expError: apierrors.NewBadRequestError(
apierrors.ErrorCodeOverRequestRateLimit,
"Invalid rate limit header value",
),
},
{
name: "malformed header, no whitespace",
headerValues: []string{
",192.168.1.100",
},
expError: apierrors.NewBadRequestError(
apierrors.ErrorCodeOverRequestRateLimit,
"Invalid rate limit header value",
),
},
{
name: "malformed header, with whitespace",
headerValues: []string{
" ,192.168.1.100",
},
expError: apierrors.NewBadRequestError(
apierrors.ErrorCodeOverRequestRateLimit,
"Invalid rate limit header value",
),
},
}

for _, tt := range tests {
// Trigger a rate limiting error if we see the same end-user key twice in the same
// test case
lmt := tollbooth.NewLimiter(
1,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
},
)

var obsError error

for _, h := range tt.headerValues {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(ts.Config.RateLimitHeader, h)
obsError = ts.API.performRateLimiting(lmt, req)
}

require.ErrorIs(ts.T(), obsError, tt.expError, "error for test '%s'", tt.name)
}
}

func (ts *MiddlewareTestSuite) TestLimitHandler() {
ts.Config.RateLimitHeader = "X-Rate-Limit"
lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{
Expand Down