Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,12 @@ Enforce reauthentication on password update.

Use this to enable/disable anonymous sign-ins.

### IP address forwarding

`GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED` - `bool`

Enable IP address forwarding using the `Sb-Forwarded-For` HTTP request header. When enabled, Auth will parse the first value of this header as an IP address and use it for IP address tracking and rate limiting. Make sure this header is fully trusted before enabling this feature by only passing it from trustworthy clients or proxies.

## Endpoints

Auth exposes the following endpoints:
Expand Down
12 changes: 11 additions & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/supabase/auth/internal/mailer/templatemailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/sbff"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/tokens"
"github.com/supabase/auth/internal/utilities"
Expand Down Expand Up @@ -152,8 +153,17 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r := newRouter()
r.UseBypass(observability.AddRequestID(globalConfig))
r.UseBypass(logger)
r.UseBypass(xffmw.Handler)
r.UseBypass(recoverer)
r.UseBypass(
sbff.Middleware(
&globalConfig.Security,
func(r *http.Request, err error) {
log := observability.GetLogEntry(r).Entry
log.WithField("error", err.Error()).Warn("error processing Sb-Forwarded-For")
},
),
)
r.UseBypass(xffmw.Handler)

if globalConfig.API.MaxRequestDuration > 0 {
r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration))
Expand Down
15 changes: 14 additions & 1 deletion internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/sbff"
"github.com/supabase/auth/internal/security"
"github.com/supabase/auth/internal/utilities"

Expand Down Expand Up @@ -61,7 +62,7 @@ func (f *FunctionHooks) UnmarshalJSON(b []byte) error {

var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered")

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
func (a *API) performRateLimitingWithHeader(lmt *limiter.Limiter, req *http.Request) error {
limitHeader := a.config.RateLimitHeader

// If no rate limit header was set, ignore rate limiting
Expand Down Expand Up @@ -112,6 +113,18 @@ func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error
return nil
}

func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
if sbffAddr, ok := sbff.GetIPAddress(req); ok {
if err := tollbooth.LimitByKeys(lmt, []string{sbffAddr}); err != nil {
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
}

return nil
}

return a.performRateLimitingWithHeader(lmt, req)
}

func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
return req.Context(), a.performRateLimiting(lmt, req)
Expand Down
162 changes: 161 additions & 1 deletion internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/sbff"
"github.com/supabase/auth/internal/storage"
)

Expand Down Expand Up @@ -415,7 +416,166 @@ func TestTimeoutResponseWriter(t *testing.T) {
require.Equal(t, w1.Result(), w2.Result())
}

func (ts *MiddlewareTestSuite) TestPerformRateLimiting() {
func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithSBFF() {
origRateLimitHeader := ts.Config.RateLimitHeader
origSBFFEnabled := ts.Config.Security.SbForwardedForEnabled

defer func() {
ts.Config.RateLimitHeader = origRateLimitHeader
ts.Config.Security.SbForwardedForEnabled = origSBFFEnabled
}()

ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting"
ts.Config.Security.SbForwardedForEnabled = true

type headerSet struct {
rateLimiting string
sbForwardedFor string
}

testCases := []struct {
name string
headerValues []headerSet
expErr error
}{
{
name: "multiple SBFF values, single rate limiting value",
headerValues: []headerSet{
{
sbForwardedFor: "192.168.1.100",
rateLimiting: "60.60.60.60",
},
{
sbForwardedFor: "192.168.1.200",
rateLimiting: "60.60.60.60",
},
},
expErr: nil,
},
{
name: "single SBFF value, multiple rate limiting values",
headerValues: []headerSet{
{
sbForwardedFor: "192.168.1.100",
rateLimiting: "60.60.60.60",
},
{
sbForwardedFor: "192.168.1.100",
rateLimiting: "70.70.70.70",
},
},
expErr: apierrors.NewTooManyRequestsError(
apierrors.ErrorCodeOverRequestRateLimit,
"Request rate limit reached",
),
},
{
name: "no SBFF value, multiple rate limiting values",
headerValues: []headerSet{
{
sbForwardedFor: "",
rateLimiting: "60.60.60.60",
},
{
sbForwardedFor: "",
rateLimiting: "70.70.70.70",
},
},
expErr: nil,
},
{
name: "no SBFF value, single rate limiting value",
headerValues: []headerSet{
{
sbForwardedFor: "",
rateLimiting: "60.60.60.60",
},
{
sbForwardedFor: "",
rateLimiting: "60.60.60.60",
},
},
expErr: apierrors.NewTooManyRequestsError(
apierrors.ErrorCodeOverRequestRateLimit,
"Request rate limit reached",
),
},
{
name: "invalid SBFF value, multiple rate limiting values",
headerValues: []headerSet{
{
sbForwardedFor: "invalid",
rateLimiting: "60.60.60.60",
},
{
sbForwardedFor: "invalid",
rateLimiting: "70.70.70.70",
},
},
expErr: nil,
},
{
name: "invalid SBFF value, single rate limiting value",
headerValues: []headerSet{
{
sbForwardedFor: "invalid",
rateLimiting: "60.60.60.60",
},
{
sbForwardedFor: "invalid",
rateLimiting: "60.60.60.60",
},
},
expErr: apierrors.NewTooManyRequestsError(
apierrors.ErrorCodeOverRequestRateLimit,
"Request rate limit reached",
),
},
}

// This test uses the SBFF middleware to inject the Sb-Forwarded-For IP address value, then
// wraps a handler that calls performRateLimiting and stores the error value.
for _, tc := range testCases {
lmt := tollbooth.NewLimiter(
1,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
},
)

var obsErr error

var handler http.HandlerFunc = func(rw http.ResponseWriter, r *http.Request) {
obsErr = ts.API.performRateLimiting(lmt, r)
}

errCallback := func(r *http.Request, err error) {
}

middleware := sbff.Middleware(&ts.Config.Security, errCallback)

wrappedHandler := middleware(handler)

for _, h := range tc.headerValues {
r := httptest.NewRequest(http.MethodGet, "http://localhost/", nil)

if h.rateLimiting != "" {
r.Header.Set(ts.Config.RateLimitHeader, h.rateLimiting)
}

if h.sbForwardedFor != "" {
r.Header.Set(sbff.HeaderName, h.sbForwardedFor)
}

wrappedHandler.ServeHTTP(nil, r)
}

require.ErrorIs(ts.T(), obsErr, tc.expErr)
}

}

func (ts *MiddlewareTestSuite) TestPerformRateLimitingWithHeader() {
ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting"

tests := []struct {
Expand Down
1 change: 1 addition & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ type SecurityConfiguration struct {
RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"`
UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"`
ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"`
SbForwardedForEnabled bool `json:"sb_forwarded_for_enabled" split_words:"true" default:"false"`

DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"`
}
Expand Down
94 changes: 94 additions & 0 deletions internal/sbff/sbff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sbff

import (
"context"
"errors"
"net"
"net/http"
"strings"

"github.com/supabase/auth/internal/conf"
)

// HeaderName is the Sb-Forwarded-For header name. It is all lowercase here as HTTP header names
// are not case-sensitive.
const HeaderName = "sb-forwarded-for"

var (
ctxKeySBFF = &struct{}{}

ErrHeaderNotFound = errors.New("Sb-Forwarded-For header not found")
ErrHeaderInvalid = errors.New("invalid Sb-Forwarded-For header value")
)

func parseSBFFHeader(headerVal string) (string, error) {
values := strings.SplitN(headerVal, ",", 2)
key := strings.TrimSpace(values[0])
if ipAddr := net.ParseIP(key); ipAddr != nil {
return ipAddr.String(), nil
}

return "", ErrHeaderInvalid
}

// GetIPAddress returns the value of the IP address in Sb-Forwarded-For as defined by
// SBForwardedForMiddleware. If no value is present in the request context, this function will
// return ("", false).
func GetIPAddress(r *http.Request) (addr string, found bool) {
if ipAddr, ok := r.Context().Value(ctxKeySBFF).(string); ok && ipAddr != "" {
return ipAddr, true
}

return "", false
}

// withIPAddress parses the Sb-Forwarded-For header and adds the leftmost value to the
// request context if it is a valid IP address, then returns a new request with modified context.
// If the leftmost value is not a valid IP address or the header is not set, this function returns
// an error.
func withIPAddress(r *http.Request) (*http.Request, error) {
headerVal := r.Header.Get(HeaderName)
if headerVal == "" {
return nil, ErrHeaderNotFound
}

parsedIPAddr, err := parseSBFFHeader(headerVal)
if err != nil {
return nil, err
}

ctx := r.Context()
newCtx := context.WithValue(ctx, ctxKeySBFF, parsedIPAddr)
out := r.WithContext(newCtx)

return out, nil
}

// Middleware returns a middleware function that parses the Sb-Forwarded-For header
// and adds the leftmost header value to the request context if GOTRUE_SECURITY_SB_FORWARDED_FOR_ENABLED
// is true and the value is a valid IP address.
func Middleware(cfg *conf.SecurityConfiguration, errCallback func(*http.Request, error)) func(http.Handler) http.Handler {
out := func(next http.Handler) http.Handler {
handlerFunc := func(rw http.ResponseWriter, r *http.Request) {
if !cfg.SbForwardedForEnabled {
next.ServeHTTP(rw, r)
return
}

reqWithSBFF, err := withIPAddress(r)
switch {
case err == nil:
next.ServeHTTP(rw, reqWithSBFF)
case errors.Is(err, ErrHeaderNotFound):
next.ServeHTTP(rw, r)
default:
errCallback(r, err)
next.ServeHTTP(rw, r)
}
}

return http.HandlerFunc(handlerFunc)
}

return out
}
Loading
Loading