Skip to content

Commit d425db7

Browse files
authored
Feat: support oauth (#202)
* support oauth * lint * avoid hard-coded port * address comments
1 parent 12c6f63 commit d425db7

File tree

13 files changed

+4391
-5
lines changed

13 files changed

+4391
-5
lines changed

docs/oauth-authentication.md

Lines changed: 474 additions & 0 deletions
Large diffs are not rendered by default.

internal/auth/oauth/endpoints.go

Lines changed: 1021 additions & 0 deletions
Large diffs are not rendered by default.

internal/auth/oauth/endpoints_test.go

Lines changed: 601 additions & 0 deletions
Large diffs are not rendered by default.

internal/auth/oauth/middleware.go

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
package oauth
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"log"
8+
"net/http"
9+
"strings"
10+
11+
"github.com/Azure/aks-mcp/internal/auth"
12+
)
13+
14+
// contextKey is a custom type for context keys to avoid collisions
15+
type contextKey string
16+
17+
const tokenInfoKey contextKey = "token_info"
18+
19+
// AuthMiddleware handles OAuth authentication for HTTP requests
20+
type AuthMiddleware struct {
21+
provider *AzureOAuthProvider
22+
serverURL string
23+
}
24+
25+
// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting
26+
func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter, r *http.Request) {
27+
requestOrigin := r.Header.Get("Origin")
28+
29+
// Check if the request origin is in the allowed list
30+
var allowedOrigin string
31+
for _, allowed := range m.provider.config.AllowedOrigins {
32+
if requestOrigin == allowed {
33+
allowedOrigin = requestOrigin
34+
break
35+
}
36+
}
37+
38+
// Only set CORS headers if origin is allowed
39+
if allowedOrigin != "" {
40+
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
41+
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
42+
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version")
43+
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
44+
w.Header().Set("Access-Control-Allow-Credentials", "false")
45+
} else if requestOrigin != "" {
46+
log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin)
47+
}
48+
}
49+
50+
// NewAuthMiddleware creates a new authentication middleware
51+
func NewAuthMiddleware(provider *AzureOAuthProvider, serverURL string) *AuthMiddleware {
52+
return &AuthMiddleware{
53+
provider: provider,
54+
serverURL: serverURL,
55+
}
56+
}
57+
58+
// Middleware returns an HTTP middleware function for OAuth authentication
59+
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
60+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
61+
62+
// Skip authentication for specific endpoints
63+
if m.shouldSkipAuth(r) {
64+
log.Printf("Skipping auth for path: %s\n", r.URL.Path)
65+
next.ServeHTTP(w, r)
66+
return
67+
}
68+
69+
// Perform authentication
70+
authResult := m.authenticateRequest(r)
71+
72+
if !authResult.Authenticated {
73+
log.Printf("Authentication FAILED - handling error\n")
74+
m.handleAuthError(w, r, authResult)
75+
return
76+
}
77+
78+
// Add token info to request context
79+
ctx := context.WithValue(r.Context(), tokenInfoKey, authResult.TokenInfo)
80+
r = r.WithContext(ctx)
81+
82+
next.ServeHTTP(w, r)
83+
})
84+
}
85+
86+
// shouldSkipAuth determines if authentication should be skipped for this request
87+
func (m *AuthMiddleware) shouldSkipAuth(r *http.Request) bool {
88+
// Skip auth for OAuth metadata endpoints
89+
path := r.URL.Path
90+
91+
skipPaths := []string{
92+
"/.well-known/oauth-protected-resource",
93+
"/.well-known/oauth-authorization-server",
94+
"/.well-known/openid-configuration",
95+
"/oauth2/v2.0/authorize",
96+
"/oauth/register",
97+
"/oauth/callback",
98+
"/oauth2/v2.0/token",
99+
"/oauth/introspect",
100+
"/health",
101+
"/ping",
102+
}
103+
104+
for _, skipPath := range skipPaths {
105+
if path == skipPath {
106+
return true
107+
}
108+
}
109+
110+
return false
111+
}
112+
113+
// authenticateRequest performs OAuth authentication on the request
114+
func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {
115+
// Extract Bearer token from Authorization header
116+
authHeader := r.Header.Get("Authorization")
117+
118+
if authHeader == "" {
119+
log.Printf("OAuth DEBUG - Missing authorization header for %s %s\n", r.Method, r.URL.Path)
120+
log.Printf("OAuth DEBUG - Request headers: %+v\n", r.Header)
121+
return &auth.AuthResult{
122+
Authenticated: false,
123+
Error: "missing authorization header",
124+
StatusCode: http.StatusUnauthorized,
125+
}
126+
}
127+
128+
// Check for Bearer token format
129+
const bearerPrefix = "Bearer "
130+
if !strings.HasPrefix(authHeader, bearerPrefix) {
131+
log.Printf("FAILED - Invalid authorization header format (missing Bearer prefix)\n")
132+
return &auth.AuthResult{
133+
Authenticated: false,
134+
Error: "invalid authorization header format",
135+
StatusCode: http.StatusUnauthorized,
136+
}
137+
}
138+
139+
token := strings.TrimPrefix(authHeader, bearerPrefix)
140+
if token == "" {
141+
log.Printf("FAILED - Empty bearer token\n")
142+
return &auth.AuthResult{
143+
Authenticated: false,
144+
Error: "empty bearer token",
145+
StatusCode: http.StatusUnauthorized,
146+
}
147+
}
148+
149+
// Basic JWT structure validation
150+
tokenParts := strings.Split(token, ".")
151+
if len(tokenParts) != 3 {
152+
log.Printf("FAILED - JWT structure validation (has %d parts, expected 3)\n", len(tokenParts))
153+
return &auth.AuthResult{
154+
Authenticated: false,
155+
Error: "invalid JWT structure",
156+
StatusCode: http.StatusUnauthorized,
157+
}
158+
}
159+
160+
// Validate the token
161+
tokenInfo, err := m.provider.ValidateToken(r.Context(), token)
162+
if err != nil {
163+
log.Printf("FAILED - Provider token validation failed: %v\n", err)
164+
return &auth.AuthResult{
165+
Authenticated: false,
166+
Error: fmt.Sprintf("token validation failed: %v", err),
167+
StatusCode: http.StatusUnauthorized,
168+
}
169+
}
170+
171+
// Validate required scopes - strict enforcement for security
172+
if !m.validateScopes(tokenInfo.Scope) {
173+
log.Printf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes)
174+
return &auth.AuthResult{
175+
Authenticated: false,
176+
Error: "insufficient scope",
177+
StatusCode: http.StatusForbidden,
178+
}
179+
}
180+
181+
return &auth.AuthResult{
182+
Authenticated: true,
183+
TokenInfo: tokenInfo,
184+
StatusCode: http.StatusOK,
185+
}
186+
}
187+
188+
// validateScopes checks if the token has required scopes
189+
func (m *AuthMiddleware) validateScopes(tokenScopes []string) bool {
190+
requiredScopes := m.provider.config.RequiredScopes
191+
if len(requiredScopes) == 0 {
192+
return true // No scopes required
193+
}
194+
195+
// Check if token has at least one required scope
196+
for _, required := range requiredScopes {
197+
if m.hasScopePermission(required, tokenScopes) {
198+
return true
199+
}
200+
}
201+
202+
return false
203+
}
204+
205+
// hasScopePermission checks if the token scopes satisfy the required scope
206+
func (m *AuthMiddleware) hasScopePermission(requiredScope string, tokenScopes []string) bool {
207+
// Direct scope match
208+
for _, tokenScope := range tokenScopes {
209+
if tokenScope == requiredScope {
210+
return true
211+
}
212+
}
213+
214+
// Azure resource scope mapping
215+
azureResourceMappings := map[string][]string{
216+
"https://management.azure.com/.default": {
217+
"user_impersonation",
218+
"https://management.azure.com/user_impersonation",
219+
"https://management.azure.com/.default",
220+
"https://management.core.windows.net/",
221+
"https://management.azure.com/",
222+
},
223+
"https://graph.microsoft.com/.default": {
224+
"User.Read",
225+
"https://graph.microsoft.com/User.Read",
226+
},
227+
}
228+
229+
if allowedScopes, exists := azureResourceMappings[requiredScope]; exists {
230+
for _, allowedScope := range allowedScopes {
231+
for _, tokenScope := range tokenScopes {
232+
if tokenScope == allowedScope {
233+
return true
234+
}
235+
}
236+
}
237+
}
238+
239+
return false
240+
}
241+
242+
// handleAuthError handles authentication errors
243+
func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, authResult *auth.AuthResult) {
244+
// Set CORS headers
245+
m.setCORSHeaders(w, r)
246+
w.Header().Set("Content-Type", "application/json")
247+
248+
// Add WWW-Authenticate header for 401 responses (RFC 9728 Section 5.1)
249+
if authResult.StatusCode == http.StatusUnauthorized {
250+
// Build the resource metadata URL
251+
scheme := "http"
252+
if r.TLS != nil {
253+
scheme = "https"
254+
}
255+
host := r.Host
256+
if host == "" {
257+
host = r.URL.Host
258+
}
259+
serverURL := fmt.Sprintf("%s://%s", scheme, host)
260+
resourceMetadataURL := fmt.Sprintf("%s/.well-known/oauth-protected-resource", serverURL)
261+
262+
// RFC 9728 compliant WWW-Authenticate header
263+
wwwAuth := fmt.Sprintf(`Bearer realm="%s", resource_metadata="%s"`, serverURL, resourceMetadataURL)
264+
265+
// Add error information if available
266+
if authResult.Error != "" {
267+
wwwAuth += fmt.Sprintf(`, error="invalid_token", error_description="%s"`, authResult.Error)
268+
}
269+
270+
w.Header().Set("WWW-Authenticate", wwwAuth)
271+
}
272+
273+
w.WriteHeader(authResult.StatusCode)
274+
275+
errorResponse := map[string]interface{}{
276+
"error": getOAuthErrorCode(authResult.StatusCode),
277+
"error_description": authResult.Error,
278+
}
279+
280+
if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
281+
log.Printf("MIDDLEWARE ERROR: Failed to encode error response: %v\n", err)
282+
} else {
283+
log.Printf("MIDDLEWARE ERROR: Error response sent\n")
284+
}
285+
}
286+
287+
// getOAuthErrorCode returns appropriate OAuth error code for HTTP status
288+
func getOAuthErrorCode(statusCode int) string {
289+
switch statusCode {
290+
case http.StatusUnauthorized:
291+
return "invalid_token"
292+
case http.StatusForbidden:
293+
return "insufficient_scope"
294+
case http.StatusBadRequest:
295+
return "invalid_request"
296+
default:
297+
return "server_error"
298+
}
299+
}

0 commit comments

Comments
 (0)