Skip to content
42 changes: 18 additions & 24 deletions internal/api/oauthserver/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

// validateRedirectURIList validates a list of redirect URIs
func validateRedirectURIList(redirectURIs []string, required bool) error {
func (s *Server) validateRedirectURIList(redirectURIs []string, required bool) error {
if required && len(redirectURIs) == 0 {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uris is required")
}
Expand All @@ -33,7 +33,7 @@ func validateRedirectURIList(redirectURIs []string, required bool) error {
}

for _, uri := range redirectURIs {
if err := validateRedirectURI(uri); err != nil {
if err := s.validateRedirectURI(uri); err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri '%s': %v", uri, err)
}
}
Expand Down Expand Up @@ -117,9 +117,9 @@ type OAuthServerClientRegisterParams struct {
}

// validate validates the OAuth client registration parameters
func (p *OAuthServerClientRegisterParams) validate() error {
func (p *OAuthServerClientRegisterParams) validate(s *Server) error {
// Validate redirect URIs (required for registration)
if err := validateRedirectURIList(p.RedirectURIs, true); err != nil {
if err := s.validateRedirectURIList(p.RedirectURIs, true); err != nil {
return err
}

Expand Down Expand Up @@ -170,8 +170,8 @@ func (p *OAuthServerClientRegisterParams) validate() error {
return nil
}

// validateRedirectURI validates OAuth 2.1 redirect URIs
func validateRedirectURI(uri string) error {
// validateRedirectURI validates OAuth 2.1 redirect URIs against the allow list configuration
func (s *Server) validateRedirectURI(uri string) error {
if uri == "" {
return fmt.Errorf("redirect URI cannot be empty")
}
Expand All @@ -181,27 +181,21 @@ func validateRedirectURI(uri string) error {
return fmt.Errorf("invalid URL format")
}

// Must have scheme and host
if parsedURL.Scheme == "" || parsedURL.Host == "" {
return fmt.Errorf("must have scheme and host")
}

// Check scheme requirements
if parsedURL.Scheme == "http" {
// HTTP only allowed for localhost
host := parsedURL.Hostname()
if host != "localhost" && host != "127.0.0.1" {
return fmt.Errorf("HTTP scheme only allowed for localhost")
}
} else if parsedURL.Scheme != "https" {
return fmt.Errorf("scheme must be HTTPS or HTTP (localhost only)")
// Must have scheme
if parsedURL.Scheme == "" {
return fmt.Errorf("must have scheme")
}

// Must not have fragment
if parsedURL.Fragment != "" {
return fmt.Errorf("fragment not allowed in redirect URI")
}

// Check against the URI allow list (supports custom schemes like cursor://, exp://, etc.)
if !utilities.IsRedirectURLValid(s.config, uri) {
return fmt.Errorf("redirect URI not allowed by configuration")
}

return nil
}

Expand Down Expand Up @@ -235,7 +229,7 @@ func ValidateClientSecret(providedSecret, storedHash string) bool {
// registerOAuthServerClient creates a new OAuth server client with generated credentials
func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthServerClientRegisterParams) (*models.OAuthServerClient, string, error) {
// Validate all parameters
if err := params.validate(); err != nil {
if err := params.validate(s); err != nil {
return nil, "", err
}

Expand Down Expand Up @@ -362,10 +356,10 @@ func (p *OAuthServerClientUpdateParams) isEmpty() bool {
}

// validate validates the OAuth client update parameters
func (p *OAuthServerClientUpdateParams) validate() error {
func (p *OAuthServerClientUpdateParams) validate(s *Server) error {
// Validate redirect URIs if provided
if p.RedirectURIs != nil {
if err := validateRedirectURIList(*p.RedirectURIs, false); err != nil {
if err := s.validateRedirectURIList(*p.RedirectURIs, false); err != nil {
return err
}
}
Expand Down Expand Up @@ -404,7 +398,7 @@ func (p *OAuthServerClientUpdateParams) validate() error {
// updateOAuthServerClient updates an existing OAuth client
func (s *Server) updateOAuthServerClient(ctx context.Context, clientID uuid.UUID, params *OAuthServerClientUpdateParams) (*models.OAuthServerClient, error) {
// Validate all parameters
if err := params.validate(); err != nil {
if err := params.validate(s); err != nil {
return nil, err
}

Expand Down
120 changes: 106 additions & 14 deletions internal/api/oauthserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"testing"

"github.com/gobwas/glob"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -59,6 +60,15 @@ func (ts *OAuthServiceTestSuite) SetupTest() {
// Enable OAuth server and dynamic client registration for tests
ts.Config.OAuthServer.Enabled = true
ts.Config.OAuthServer.AllowDynamicRegistration = true

// Add test URIs to allow list for testing
ts.Config.URIAllowList = append(ts.Config.URIAllowList, "https://example.com/**", "https://app.example.com/**")
// Rebuild the allow list map
ts.Config.URIAllowListMap = make(map[string]glob.Glob)
for _, uri := range ts.Config.URIAllowList {
g := glob.MustCompile(uri, '.', '/')
ts.Config.URIAllowListMap[uri] = g
}
}

// Helper function to create test OAuth client
Expand Down Expand Up @@ -275,18 +285,6 @@ func (ts *OAuthServiceTestSuite) TestRedirectURIValidation() {
shouldError: true,
errorMsg: "redirect URI cannot be empty",
},
{
name: "Invalid scheme",
uri: "ftp://example.com/callback",
shouldError: true,
errorMsg: "scheme must be HTTPS or HTTP (localhost only)",
},
{
name: "Invalid HTTP non-localhost",
uri: "http://example.com/callback",
shouldError: true,
errorMsg: "HTTP scheme only allowed for localhost",
},
{
name: "Invalid URI with fragment",
uri: "https://example.com/callback#fragment",
Expand All @@ -297,13 +295,19 @@ func (ts *OAuthServiceTestSuite) TestRedirectURIValidation() {
name: "Invalid URI format",
uri: "not-a-uri",
shouldError: true,
errorMsg: "must have scheme and host",
errorMsg: "must have scheme",
},
{
name: "URI not in allow list",
uri: "ftp://example.com/callback",
shouldError: true,
errorMsg: "redirect URI not allowed by configuration",
},
}

for _, tc := range testCases {
ts.T().Run(tc.name, func(t *testing.T) {
err := validateRedirectURI(tc.uri)
err := ts.Server.validateRedirectURI(tc.uri)
if tc.shouldError {
assert.Error(t, err)
if tc.errorMsg != "" {
Expand Down Expand Up @@ -336,3 +340,91 @@ func (ts *OAuthServiceTestSuite) TestGrantTypeDefaults() {
assert.Contains(ts.T(), grantTypes, "refresh_token")
assert.Len(ts.T(), grantTypes, 2)
}

func (ts *OAuthServiceTestSuite) TestCustomURISchemes() {
// Test custom URI schemes when they're in the allow list
// This tests the fix for issue #2285

// Save original allow list
originalAllowList := ts.Config.URIAllowList
originalAllowListMap := ts.Config.URIAllowListMap
defer func() {
ts.Config.URIAllowList = originalAllowList
ts.Config.URIAllowListMap = originalAllowListMap
}()

// Configure allow list with custom schemes (keep existing + add custom)
ts.Config.URIAllowList = append([]string{}, originalAllowList...)
ts.Config.URIAllowList = append(ts.Config.URIAllowList, "cursor://**", "com.example.app://**", "exp://**")
// Rebuild the allow list map
ts.Config.URIAllowListMap = make(map[string]glob.Glob)
for _, uri := range ts.Config.URIAllowList {
g := glob.MustCompile(uri, '.', '/')
ts.Config.URIAllowListMap[uri] = g
}

ctx := context.Background()

// Test 1: cursor:// scheme (for Cursor IDE)
params := &OAuthServerClientRegisterParams{
ClientName: "Cursor IDE",
RedirectURIs: []string{"cursor://anysphere.cursor-mcp/callback"},
RegistrationType: "dynamic",
}

client, secret, err := ts.Server.registerOAuthServerClient(ctx, params)
require.NoError(ts.T(), err, "Should allow cursor:// scheme when in allow list")
require.NotNil(ts.T(), client)
require.NotEmpty(ts.T(), secret)
assert.Equal(ts.T(), "Cursor IDE", *client.ClientName)
assert.Equal(ts.T(), []string{"cursor://anysphere.cursor-mcp/callback"}, client.GetRedirectURIs())

// Test 2: Mobile app scheme (com.example.app://)
params = &OAuthServerClientRegisterParams{
ClientName: "Mobile App",
RedirectURIs: []string{"com.example.app://sign-in/v2"},
RegistrationType: "dynamic",
}

client, secret, err = ts.Server.registerOAuthServerClient(ctx, params)
require.NoError(ts.T(), err, "Should allow com.example.app:// scheme when in allow list")
require.NotNil(ts.T(), client)
require.NotEmpty(ts.T(), secret)
assert.Equal(ts.T(), "Mobile App", *client.ClientName)

// Test 3: Expo scheme (exp://)
params = &OAuthServerClientRegisterParams{
ClientName: "Expo App",
RedirectURIs: []string{"exp://192.168.1.1:19000/--/auth/callback"},
RegistrationType: "dynamic",
}

client, secret, err = ts.Server.registerOAuthServerClient(ctx, params)
require.NoError(ts.T(), err, "Should allow exp:// scheme when in allow list")
require.NotNil(ts.T(), client)
require.NotEmpty(ts.T(), secret)

// Test 4: Unauthorized custom scheme should fail
params = &OAuthServerClientRegisterParams{
ClientName: "Malicious App",
RedirectURIs: []string{"malicious://attack"},
RegistrationType: "dynamic",
}

_, _, err = ts.Server.registerOAuthServerClient(ctx, params)
assert.Error(ts.T(), err, "Should reject custom scheme not in allow list")
assert.Contains(ts.T(), err.Error(), "redirect URI not allowed by configuration")

// Test 5: Mix of custom and standard schemes
params = &OAuthServerClientRegisterParams{
ClientName: "Multi-Platform App",
RedirectURIs: []string{"https://example.com/callback", "cursor://app/callback"},
RegistrationType: "dynamic",
}

client, secret, err = ts.Server.registerOAuthServerClient(ctx, params)
require.NoError(ts.T(), err, "Should allow mix of standard and custom schemes")
require.NotNil(ts.T(), client)
require.NotEmpty(ts.T(), secret)
assert.Len(ts.T(), client.GetRedirectURIs(), 2)
}