diff --git a/Cargo.lock b/Cargo.lock index 368e6fc56..f74c12d1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5529,6 +5529,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest 0.13.1", + "ring", "ruma", "rustls", "rustyline-async", diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs index 96ae1dc1b..1a5352e17 100644 --- a/src/api/client/mod.rs +++ b/src/api/client/mod.rs @@ -15,6 +15,7 @@ pub(super) mod media; pub(super) mod media_legacy; pub(super) mod membership; pub(super) mod message; +pub(super) mod oidc; pub(super) mod openid; pub(super) mod presence; pub(super) mod profile; @@ -63,6 +64,7 @@ pub(super) use media::*; pub(super) use media_legacy::*; pub(super) use membership::*; pub(super) use message::*; +pub(super) use oidc::*; pub(super) use openid::*; pub(super) use presence::*; pub(super) use profile::*; diff --git a/src/api/client/oidc.rs b/src/api/client/oidc.rs new file mode 100644 index 000000000..629ad921c --- /dev/null +++ b/src/api/client/oidc.rs @@ -0,0 +1,518 @@ +use std::time::SystemTime; + +use axum::{ + Json, + extract::State, + response::{IntoResponse, Redirect}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; +use http::StatusCode; +use ruma::OwnedDeviceId; +use serde::{Deserialize, Serialize}; +use tuwunel_core::{Err, Result, err, info, utils}; +use tuwunel_service::{ + oauth::oidc_server::{ + DcrRequest, IdTokenClaims, OidcAuthRequest, OidcServer, ProviderMetadata, + }, + users::device::generate_refresh_token, +}; + +const OIDC_REQ_ID_LENGTH: usize = 32; + +#[derive(Serialize)] +struct AuthIssuerResponse { + issuer: String, +} + +pub(crate) async fn auth_issuer_route( + State(services): State, +) -> Result { + let issuer = oidc_issuer_url(&services)?; + Ok(Json(AuthIssuerResponse { issuer })) +} + +pub(crate) async fn openid_configuration_route( + State(services): State, +) -> Result { + Ok(Json(oidc_metadata(&services)?)) +} + +fn oidc_metadata(services: &tuwunel_service::Services) -> Result { + let issuer = oidc_issuer_url(services)?; + let base = issuer.trim_end_matches('/').to_owned(); + + Ok(ProviderMetadata { + issuer, + authorization_endpoint: format!("{base}/_tuwunel/oidc/authorize"), + token_endpoint: format!("{base}/_tuwunel/oidc/token"), + registration_endpoint: Some(format!("{base}/_tuwunel/oidc/registration")), + revocation_endpoint: Some(format!("{base}/_tuwunel/oidc/revoke")), + jwks_uri: format!("{base}/_tuwunel/oidc/jwks"), + userinfo_endpoint: Some(format!("{base}/_tuwunel/oidc/userinfo")), + account_management_uri: Some(format!("{base}/_tuwunel/oidc/account")), + account_management_actions_supported: Some(vec![ + "org.matrix.profile".to_owned(), + "org.matrix.sessions_list".to_owned(), + "org.matrix.session_view".to_owned(), + "org.matrix.session_end".to_owned(), + "org.matrix.cross_signing_reset".to_owned(), + ]), + response_types_supported: vec!["code".to_owned()], + response_modes_supported: Some(vec!["query".to_owned(), "fragment".to_owned()]), + grant_types_supported: Some(vec![ + "authorization_code".to_owned(), + "refresh_token".to_owned(), + ]), + code_challenge_methods_supported: Some(vec!["S256".to_owned()]), + token_endpoint_auth_methods_supported: Some(vec![ + "none".to_owned(), + "client_secret_basic".to_owned(), + "client_secret_post".to_owned(), + ]), + scopes_supported: Some(vec![ + "openid".to_owned(), + "urn:matrix:org.matrix.msc2967.client:api:*".to_owned(), + "urn:matrix:org.matrix.msc2967.client:device:*".to_owned(), + ]), + subject_types_supported: Some(vec!["public".to_owned()]), + id_token_signing_alg_values_supported: Some(vec!["ES256".to_owned()]), + prompt_values_supported: Some(vec!["create".to_owned()]), + claim_types_supported: Some(vec!["normal".to_owned()]), + claims_supported: Some(vec![ + "iss".to_owned(), + "sub".to_owned(), + "aud".to_owned(), + "exp".to_owned(), + "iat".to_owned(), + "nonce".to_owned(), + ]), + }) +} + +pub(crate) async fn registration_route( + State(services): State, + Json(body): Json, +) -> Result { + let Ok(oidc) = get_oidc_server(&services) else { + return Err!(Request(NotFound("OIDC server not configured"))); + }; + + if body.redirect_uris.is_empty() { + return Err!(Request(InvalidParam("redirect_uris must not be empty"))); + } + + let reg = oidc.register_client(body)?; + info!( + "OIDC client registered: {} ({})", + reg.client_id, + reg.client_name.as_deref().unwrap_or("unnamed") + ); + + Ok(( + StatusCode::CREATED, + Json(serde_json::json!({ + "client_id": reg.client_id, + "client_id_issued_at": reg.registered_at, + "redirect_uris": reg.redirect_uris, + "client_name": reg.client_name, + "client_uri": reg.client_uri, + "logo_uri": reg.logo_uri, + "contacts": reg.contacts, + "token_endpoint_auth_method": reg.token_endpoint_auth_method, + "grant_types": reg.grant_types, + "response_types": reg.response_types, + "application_type": reg.application_type, + "policy_uri": reg.policy_uri, + "tos_uri": reg.tos_uri, + "software_id": reg.software_id, + "software_version": reg.software_version, + })), + )) +} + +#[derive(Debug, Deserialize)] +pub(crate) struct AuthorizeParams { + client_id: String, + redirect_uri: String, + response_type: String, + scope: String, + state: Option, + nonce: Option, + code_challenge: Option, + code_challenge_method: Option, + #[serde(default, rename = "prompt")] + _prompt: Option, +} + +pub(crate) async fn authorize_route( + State(services): State, + request: axum::extract::Request, +) -> Result { + let params: AuthorizeParams = + serde_html_form::from_str(request.uri().query().unwrap_or_default())?; + let Ok(oidc) = get_oidc_server(&services) else { + return Err!(Request(NotFound("OIDC server not configured"))); + }; + + if params.response_type != "code" { + return Err!(Request(InvalidParam("Only response_type=code is supported"))); + } + + oidc.validate_redirect_uri(¶ms.client_id, ¶ms.redirect_uri) + .await?; + + let req_id = utils::random_string(OIDC_REQ_ID_LENGTH); + let now = SystemTime::now(); + + oidc.store_auth_request(&req_id, &OidcAuthRequest { + client_id: params.client_id, + redirect_uri: params.redirect_uri, + scope: params.scope, + state: params.state, + nonce: params.nonce, + code_challenge: params.code_challenge, + code_challenge_method: params.code_challenge_method, + created_at: now, + expires_at: now + .checked_add(OidcServer::auth_request_lifetime()) + .unwrap_or(now), + }); + + let default_idp = services + .config + .identity_provider + .values() + .find(|idp| idp.default) + .or_else(|| services.config.identity_provider.values().next()) + .ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?; + let idp_id = default_idp.id(); + + let base = oidc_issuer_url(&services)?; + let base = base.trim_end_matches('/'); + + let mut complete_url = url::Url::parse(&format!("{base}/_tuwunel/oidc/_complete")) + .map_err(|_| err!(error!("Failed to build complete URL")))?; + complete_url + .query_pairs_mut() + .append_pair("oidc_req_id", &req_id); + + let mut sso_url = + url::Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id}")) + .map_err(|_| err!(error!("Failed to build SSO URL")))?; + sso_url + .query_pairs_mut() + .append_pair("redirectUrl", complete_url.as_str()); + + Ok(Redirect::temporary(sso_url.as_str())) +} + +#[derive(Debug, Deserialize)] +pub(crate) struct CompleteParams { + oidc_req_id: String, + #[serde(rename = "loginToken")] + login_token: String, +} + +pub(crate) async fn complete_route( + State(services): State, + request: axum::extract::Request, +) -> Result { + let params: CompleteParams = + serde_html_form::from_str(request.uri().query().unwrap_or_default())?; + let Ok(oidc) = get_oidc_server(&services) else { + return Err!(Request(NotFound("OIDC server not configured"))); + }; + + let user_id = services + .users + .find_from_login_token(¶ms.login_token) + .await + .map_err(|_| err!(Request(Forbidden("Invalid or expired login token"))))?; + let auth_req = oidc + .take_auth_request(¶ms.oidc_req_id) + .await?; + let code = oidc.create_auth_code(&auth_req, user_id); + + let mut redirect_url = url::Url::parse(&auth_req.redirect_uri) + .map_err(|_| err!(Request(InvalidParam("Invalid redirect_uri"))))?; + redirect_url + .query_pairs_mut() + .append_pair("code", &code); + if let Some(state) = &auth_req.state { + redirect_url + .query_pairs_mut() + .append_pair("state", state); + } + + Ok(Redirect::temporary(redirect_url.as_str())) +} + +#[derive(Debug, Deserialize)] +pub(crate) struct TokenRequest { + grant_type: String, + code: Option, + redirect_uri: Option, + client_id: Option, + code_verifier: Option, + refresh_token: Option, + #[serde(rename = "scope")] + _scope: Option, +} + +pub(crate) async fn token_route( + State(services): State, + axum::extract::Form(body): axum::extract::Form, +) -> impl IntoResponse { + match body.grant_type.as_str() { + | "authorization_code" => token_authorization_code(&services, &body) + .await + .unwrap_or_else(|e| { + oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string()) + }), + | "refresh_token" => token_refresh(&services, &body) + .await + .unwrap_or_else(|e| { + oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string()) + }), + | _ => oauth_error( + StatusCode::BAD_REQUEST, + "unsupported_grant_type", + "Unsupported grant_type", + ), + } +} + +async fn token_authorization_code( + services: &tuwunel_service::Services, + body: &TokenRequest, +) -> Result> { + let code = body + .code + .as_deref() + .ok_or_else(|| err!(Request(InvalidParam("code is required"))))?; + let redirect_uri = body + .redirect_uri + .as_deref() + .ok_or_else(|| err!(Request(InvalidParam("redirect_uri is required"))))?; + let client_id = body + .client_id + .as_deref() + .ok_or_else(|| err!(Request(InvalidParam("client_id is required"))))?; + + let oidc = get_oidc_server(services)?; + let session = oidc + .exchange_auth_code(code, client_id, redirect_uri, body.code_verifier.as_deref()) + .await?; + + let user_id = &session.user_id; + let (access_token, expires_in) = services.users.generate_access_token(true); + let refresh_token = generate_refresh_token(); + + let client_name = oidc + .get_client(client_id) + .await + .ok() + .and_then(|c| c.client_name); + let device_display_name = client_name.as_deref().unwrap_or("OIDC Client"); + let device_id: Option = + extract_device_id(&session.scope).map(OwnedDeviceId::from); + let device_id = services + .users + .create_device( + user_id, + device_id.as_deref(), + (Some(&access_token), expires_in), + Some(&refresh_token), + Some(device_display_name), + None, + ) + .await?; + + info!("{user_id} logged in via OIDC (device {device_id})"); + + let id_token = if session.scope.contains("openid") { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let issuer = oidc_issuer_url(services)?; + let claims = IdTokenClaims { + iss: issuer, + sub: user_id.to_string(), + aud: client_id.to_owned(), + exp: now.saturating_add(3600), + iat: now, + nonce: session.nonce, + at_hash: Some(OidcServer::at_hash(&access_token)), + }; + Some(oidc.sign_id_token(&claims)?) + } else { + None + }; + + let mut response = serde_json::json!({ + "access_token": access_token, + "token_type": "Bearer", + "scope": session.scope, + "refresh_token": refresh_token, + }); + if let Some(expires_in) = expires_in { + response["expires_in"] = serde_json::json!(expires_in.as_secs()); + } + if let Some(id_token) = id_token { + response["id_token"] = serde_json::json!(id_token); + } + + Ok(Json(response).into_response()) +} + +async fn token_refresh( + services: &tuwunel_service::Services, + body: &TokenRequest, +) -> Result> { + let refresh_token = body + .refresh_token + .as_deref() + .ok_or_else(|| err!(Request(InvalidParam("refresh_token is required"))))?; + let (user_id, device_id, _) = services + .users + .find_from_token(refresh_token) + .await + .map_err(|_| err!(Request(Forbidden("Invalid refresh token"))))?; + + let (new_access_token, expires_in) = services.users.generate_access_token(true); + let new_refresh_token = generate_refresh_token(); + services + .users + .set_access_token( + &user_id, + &device_id, + &new_access_token, + expires_in, + Some(&new_refresh_token), + ) + .await?; + + let mut response = serde_json::json!({ + "access_token": new_access_token, + "token_type": "Bearer", + "refresh_token": new_refresh_token, + }); + if let Some(expires_in) = expires_in { + response["expires_in"] = serde_json::json!(expires_in.as_secs()); + } + + Ok(Json(response).into_response()) +} + +#[derive(Debug, Deserialize)] +pub(crate) struct RevokeRequest { + token: String, + #[serde(default, rename = "token_type_hint")] + _token_type_hint: Option, +} + +pub(crate) async fn revoke_route( + State(services): State, + axum::extract::Form(body): axum::extract::Form, +) -> Result { + if let Ok((user_id, device_id, _)) = services.users.find_from_token(&body.token).await { + services + .users + .remove_device(&user_id, &device_id) + .await; + } + Ok(Json(serde_json::json!({}))) +} + +pub(crate) async fn jwks_route( + State(services): State, +) -> Result { + let oidc = get_oidc_server(&services)?; + Ok(Json(oidc.jwks())) +} + +pub(crate) async fn userinfo_route( + State(services): State, + TypedHeader(Authorization(bearer)): TypedHeader>, +) -> Result { + let token = bearer.token(); + let Ok((user_id, _device_id, _expires)) = services.users.find_from_token(token).await else { + return Err!(Request(Unauthorized("Invalid access token"))); + }; + let displayname = services.users.displayname(&user_id).await.ok(); + let avatar_url = services.users.avatar_url(&user_id).await.ok(); + Ok(Json(serde_json::json!({ + "sub": user_id.to_string(), + "name": displayname, + "picture": avatar_url, + }))) +} + +pub(crate) async fn account_route( + State(services): State, +) -> Result { + // Redirect to the identity provider's panel where users can manage + // their account, sessions, devices, and profile. + let idp = services + .config + .identity_provider + .values() + .find(|idp| idp.default) + .or_else(|| services.config.identity_provider.values().next()) + .ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?; + + let panel_url = idp.issuer_url.as_ref().ok_or_else(|| { + err!(Config("issuer_url", "issuer_url is required for account management redirect")) + })?; + + Ok(Redirect::temporary(panel_url.as_str())) +} + +fn oauth_error( + status: StatusCode, + error: &str, + description: &str, +) -> http::Response { + ( + status, + Json(serde_json::json!({ + "error": error, + "error_description": description, + })), + ) + .into_response() +} + +fn get_oidc_server(services: &tuwunel_service::Services) -> Result<&OidcServer> { + services + .oauth + .oidc_server + .as_deref() + .ok_or_else(|| err!(Request(NotFound("OIDC server not configured")))) +} + +fn oidc_issuer_url(services: &tuwunel_service::Services) -> Result { + services + .config + .well_known + .client + .as_ref() + .map(|url| { + let s = url.to_string(); + if s.ends_with('/') { s } else { s + "/" } + }) + .ok_or_else(|| { + err!(Config("well_known.client", "well_known.client must be set for OIDC server")) + }) +} + +fn extract_device_id(scope: &str) -> Option { + scope + .split_whitespace() + .find_map(|s| s.strip_prefix("urn:matrix:org.matrix.msc2967.client:device:")) + .map(ToOwned::to_owned) +} diff --git a/src/api/client/versions.rs b/src/api/client/versions.rs index 203d34d06..b019aa427 100644 --- a/src/api/client/versions.rs +++ b/src/api/client/versions.rs @@ -51,7 +51,7 @@ static VERSIONS: [&str; 17] = [ "v1.15", /* custom profile fields */ ]; -static UNSTABLE_FEATURES: [&str; 18] = [ +static UNSTABLE_FEATURES: [&str; 22] = [ "org.matrix.e2e_cross_signing", // private read receipts (https://github.com/matrix-org/matrix-spec-proposals/pull/2285) "org.matrix.msc2285.stable", @@ -86,4 +86,12 @@ static UNSTABLE_FEATURES: [&str; 18] = [ "org.matrix.simplified_msc3575", // Allow room moderators to view redacted event content (https://github.com/matrix-org/matrix-spec-proposals/pull/2815) "fi.mau.msc2815", + // OIDC-native auth: authorization code grant (https://github.com/matrix-org/matrix-spec-proposals/pull/2964) + "org.matrix.msc2964", + // OIDC-native auth: auth issuer discovery (https://github.com/matrix-org/matrix-spec-proposals/pull/2965) + "org.matrix.msc2965", + // OIDC-native auth: dynamic client registration (https://github.com/matrix-org/matrix-spec-proposals/pull/2966) + "org.matrix.msc2966", + // OIDC-native auth: API scopes (https://github.com/matrix-org/matrix-spec-proposals/pull/2967) + "org.matrix.msc2967", ]; diff --git a/src/api/router.rs b/src/api/router.rs index fd29acab4..8a1e8435e 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -198,6 +198,20 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::well_known_support) .ruma_route(&client::well_known_client) .route("/_tuwunel/server_version", get(client::tuwunel_server_version)) + // OIDC server endpoints (next-gen auth, MSC2965/2964/2966/2967) + .route("/_matrix/client/unstable/org.matrix.msc2965/auth_issuer", get(client::auth_issuer_route)) + .route("/_matrix/client/v1/auth_issuer", get(client::auth_issuer_route)) + .route("/_matrix/client/unstable/org.matrix.msc2965/auth_metadata", get(client::openid_configuration_route)) + .route("/_matrix/client/v1/auth_metadata", get(client::openid_configuration_route)) + .route("/.well-known/openid-configuration", get(client::openid_configuration_route)) + .route("/_tuwunel/oidc/registration", post(client::registration_route)) + .route("/_tuwunel/oidc/authorize", get(client::authorize_route)) + .route("/_tuwunel/oidc/_complete", get(client::complete_route)) + .route("/_tuwunel/oidc/token", post(client::token_route)) + .route("/_tuwunel/oidc/revoke", post(client::revoke_route)) + .route("/_tuwunel/oidc/jwks", get(client::jwks_route)) + .route("/_tuwunel/oidc/userinfo", get(client::userinfo_route)) + .route("/_tuwunel/oidc/account", get(client::account_route)) .ruma_route(&client::room_initial_sync_route); // SS endpoints not related to federation diff --git a/src/database/maps.rs b/src/database/maps.rs index 5ec270bbd..9578f745b 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -145,6 +145,22 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "oauthuniqid_oauthid", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "oidc_signingkey", + ..descriptor::RANDOM_SMALL + }, + Descriptor { + name: "oidcclientid_registration", + ..descriptor::RANDOM_SMALL + }, + Descriptor { + name: "oidccode_authsession", + ..descriptor::RANDOM_SMALL + }, + Descriptor { + name: "oidcreqid_authrequest", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "onetimekeyid_onetimekeys", ..descriptor::RANDOM_SMALL diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index 1e23bc912..2fd20f97c 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -104,6 +104,7 @@ lru-cache.workspace = true rand.workspace = true regex.workspace = true reqwest.workspace = true +ring.workspace = true ruma.workspace = true rustls.workspace = true rustyline-async.workspace = true diff --git a/src/service/oauth/mod.rs b/src/service/oauth/mod.rs index 3cd4894cb..5ddd4b8bb 100644 --- a/src/service/oauth/mod.rs +++ b/src/service/oauth/mod.rs @@ -1,3 +1,4 @@ +pub mod oidc_server; pub mod providers; pub mod sessions; pub mod user_info; @@ -14,13 +15,15 @@ use ruma::UserId; use serde::Serialize; use serde_json::Value as JsonValue; use tuwunel_core::{ - Err, Result, err, implement, + Err, Result, err, implement, info, utils::{hash::sha256, result::LogErr, stream::ReadyExt}, + warn, }; use url::Url; -use self::{providers::Providers, sessions::Sessions}; +use self::{oidc_server::OidcServer, providers::Providers, sessions::Sessions}; pub use self::{ + oidc_server::ProviderMetadata, providers::{Provider, ProviderId}, sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId}, user_info::UserInfo, @@ -31,16 +34,36 @@ pub struct Service { services: SelfServices, pub providers: Arc, pub sessions: Arc, + /// OIDC server (authorization server) for next-gen Matrix auth. + /// Only initialized when identity providers are configured. + pub oidc_server: Option>, } impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { let providers = Arc::new(Providers::build(args)); let sessions = Arc::new(Sessions::build(args, providers.clone())); + + let oidc_server = if !args.server.config.identity_provider.is_empty() + || args.server.config.well_known.client.is_some() + { + if args.server.config.identity_provider.is_empty() { + warn!( + "OIDC server enabled (well_known.client is set) but no identity_provider \ + configured; authorization flow will not work" + ); + } + info!("Initializing OIDC server for next-gen auth (MSC2965)"); + Some(Arc::new(OidcServer::build(args)?)) + } else { + None + }; + Ok(Arc::new(Self { services: args.services.clone(), sessions, providers, + oidc_server, })) } diff --git a/src/service/oauth/oidc_server.rs b/src/service/oauth/oidc_server.rs new file mode 100644 index 000000000..6d6e9aa9a --- /dev/null +++ b/src/service/oauth/oidc_server.rs @@ -0,0 +1,431 @@ +use std::{ + net::IpAddr, + sync::Arc, + time::{Duration, SystemTime}, +}; + +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64}; +use ring::{ + rand::SystemRandom, + signature::{self, EcdsaKeyPair, KeyPair}, +}; +use ruma::OwnedUserId; +use serde::{Deserialize, Serialize}; +use tuwunel_core::{Err, Result, err, info, jwt, utils}; +use tuwunel_database::{Cbor, Deserialized, Map}; + +const AUTH_CODE_LENGTH: usize = 64; +const OIDC_CLIENT_ID_LENGTH: usize = 32; +const AUTH_CODE_LIFETIME: Duration = Duration::from_mins(10); +const AUTH_REQUEST_LIFETIME: Duration = Duration::from_mins(10); +const SIGNING_KEY_DB_KEY: &str = "oidc_signing_key"; + +pub struct OidcServer { + db: Data, + signing_key_der: Vec, + jwk: serde_json::Value, + key_id: String, +} + +struct Data { + oidc_signingkey: Arc, + oidcclientid_registration: Arc, + oidccode_authsession: Arc, + oidcreqid_authrequest: Arc, +} + +#[derive(Debug, Deserialize)] +pub struct DcrRequest { + pub redirect_uris: Vec, + pub client_name: Option, + pub client_uri: Option, + pub logo_uri: Option, + #[serde(default)] + pub contacts: Vec, + pub token_endpoint_auth_method: Option, + pub grant_types: Option>, + pub response_types: Option>, + pub application_type: Option, + pub policy_uri: Option, + pub tos_uri: Option, + pub software_id: Option, + pub software_version: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct OidcClientRegistration { + pub client_id: String, + pub redirect_uris: Vec, + pub client_name: Option, + pub client_uri: Option, + pub logo_uri: Option, + pub contacts: Vec, + pub token_endpoint_auth_method: String, + pub grant_types: Vec, + pub response_types: Vec, + pub application_type: Option, + pub policy_uri: Option, + pub tos_uri: Option, + pub software_id: Option, + pub software_version: Option, + pub registered_at: u64, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AuthCodeSession { + pub code: String, + pub client_id: String, + pub redirect_uri: String, + pub scope: String, + pub state: Option, + pub nonce: Option, + pub code_challenge: Option, + pub code_challenge_method: Option, + pub user_id: OwnedUserId, + pub created_at: SystemTime, + pub expires_at: SystemTime, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct OidcAuthRequest { + pub client_id: String, + pub redirect_uri: String, + pub scope: String, + pub state: Option, + pub nonce: Option, + pub code_challenge: Option, + pub code_challenge_method: Option, + pub created_at: SystemTime, + pub expires_at: SystemTime, +} + +#[derive(Serialize, Deserialize)] +struct SigningKeyData { + key_der: Vec, + key_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ProviderMetadata { + pub issuer: String, + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: Option, + pub revocation_endpoint: Option, + pub jwks_uri: String, + pub userinfo_endpoint: Option, + pub account_management_uri: Option, + pub account_management_actions_supported: Option>, + pub response_types_supported: Vec, + pub response_modes_supported: Option>, + pub grant_types_supported: Option>, + pub code_challenge_methods_supported: Option>, + pub token_endpoint_auth_methods_supported: Option>, + pub scopes_supported: Option>, + pub subject_types_supported: Option>, + pub id_token_signing_alg_values_supported: Option>, + pub prompt_values_supported: Option>, + pub claim_types_supported: Option>, + pub claims_supported: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IdTokenClaims { + pub iss: String, + pub sub: String, + pub aud: String, + pub exp: u64, + pub iat: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub nonce: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub at_hash: Option, +} + +impl ProviderMetadata { + #[must_use] + pub fn into_json(self) -> serde_json::Value { + serde_json::to_value(self).expect("ProviderMetadata serialization") + } +} + +impl OidcServer { + pub(crate) fn build(args: &crate::Args<'_>) -> Result { + let db = Data { + oidc_signingkey: args.db["oidc_signingkey"].clone(), + oidcclientid_registration: args.db["oidcclientid_registration"].clone(), + oidccode_authsession: args.db["oidccode_authsession"].clone(), + oidcreqid_authrequest: args.db["oidcreqid_authrequest"].clone(), + }; + + let (signing_key_der, key_id) = match db + .oidc_signingkey + .get_blocking(SIGNING_KEY_DB_KEY) + .and_then(|handle| { + handle + .deserialized::>() + .map(|cbor| cbor.0) + }) { + | Ok(data) => { + info!("Loaded existing OIDC signing key (kid={})", data.key_id); + (data.key_der, data.key_id) + }, + | Err(_) => { + let (key_der, key_id) = Self::generate_signing_key()?; + info!("Generated new OIDC signing key (kid={key_id})"); + let data = SigningKeyData { + key_der: key_der.clone(), + key_id: key_id.clone(), + }; + db.oidc_signingkey + .raw_put(SIGNING_KEY_DB_KEY, Cbor(&data)); + (key_der, key_id) + }, + }; + + let jwk = Self::build_jwk(&signing_key_der, &key_id)?; + Ok(Self { db, signing_key_der, jwk, key_id }) + } + + fn generate_signing_key() -> Result<(Vec, String)> { + let rng = SystemRandom::new(); + let alg = &signature::ECDSA_P256_SHA256_FIXED_SIGNING; + let pkcs8 = EcdsaKeyPair::generate_pkcs8(alg, &rng) + .map_err(|e| err!(error!("Failed to generate ECDSA key: {e}")))?; + let key_id = utils::random_string(16); + Ok((pkcs8.as_ref().to_vec(), key_id)) + } + + fn build_jwk(signing_key_der: &[u8], key_id: &str) -> Result { + let rng = SystemRandom::new(); + let alg = &signature::ECDSA_P256_SHA256_FIXED_SIGNING; + let key_pair = EcdsaKeyPair::from_pkcs8(alg, signing_key_der, &rng) + .map_err(|e| err!(error!("Failed to load ECDSA key: {e}")))?; + let public_bytes = key_pair.public_key().as_ref(); + let x = b64.encode(&public_bytes[1..33]); + let y = b64.encode(&public_bytes[33..65]); + Ok(serde_json::json!({ + "kty": "EC", + "crv": "P-256", + "use": "sig", + "alg": "ES256", + "kid": key_id, + "x": x, + "y": y, + })) + } + + pub fn register_client(&self, request: DcrRequest) -> Result { + let client_id = utils::random_string(OIDC_CLIENT_ID_LENGTH); + let auth_method = request + .token_endpoint_auth_method + .unwrap_or_else(|| "none".to_owned()); + let registration = OidcClientRegistration { + client_id: client_id.clone(), + redirect_uris: request.redirect_uris, + client_name: request.client_name, + client_uri: request.client_uri, + logo_uri: request.logo_uri, + contacts: request.contacts, + token_endpoint_auth_method: auth_method, + grant_types: request.grant_types.unwrap_or_else(|| { + vec!["authorization_code".to_owned(), "refresh_token".to_owned()] + }), + response_types: request + .response_types + .unwrap_or_else(|| vec!["code".to_owned()]), + application_type: request.application_type, + policy_uri: request.policy_uri, + tos_uri: request.tos_uri, + software_id: request.software_id, + software_version: request.software_version, + registered_at: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + self.db + .oidcclientid_registration + .raw_put(&*client_id, Cbor(®istration)); + Ok(registration) + } + + pub async fn get_client(&self, client_id: &str) -> Result { + self.db + .oidcclientid_registration + .get(client_id) + .await + .deserialized::>() + .map(|cbor: Cbor| cbor.0) + .map_err(|_| err!(Request(NotFound("Unknown client_id")))) + } + + pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result { + let client = self.get_client(client_id).await?; + if client + .redirect_uris + .iter() + .any(|uri| redirect_uri_matches(uri, redirect_uri)) + { + Ok(()) + } else { + Err!(Request(InvalidParam("redirect_uri not registered for this client"))) + } + } + + pub fn store_auth_request(&self, req_id: &str, request: &OidcAuthRequest) { + self.db + .oidcreqid_authrequest + .raw_put(req_id, Cbor(request)); + } + + pub async fn take_auth_request(&self, req_id: &str) -> Result { + let request: OidcAuthRequest = self + .db + .oidcreqid_authrequest + .get(req_id) + .await + .deserialized::>() + .map(|cbor: Cbor| cbor.0) + .map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?; + self.db.oidcreqid_authrequest.remove(req_id); + if SystemTime::now() > request.expires_at { + return Err!(Request(NotFound("Authorization request has expired"))); + } + Ok(request) + } + + #[must_use] + pub fn create_auth_code(&self, auth_req: &OidcAuthRequest, user_id: OwnedUserId) -> String { + let code = utils::random_string(AUTH_CODE_LENGTH); + let now = SystemTime::now(); + let session = AuthCodeSession { + code: code.clone(), + client_id: auth_req.client_id.clone(), + redirect_uri: auth_req.redirect_uri.clone(), + scope: auth_req.scope.clone(), + state: auth_req.state.clone(), + nonce: auth_req.nonce.clone(), + code_challenge: auth_req.code_challenge.clone(), + code_challenge_method: auth_req.code_challenge_method.clone(), + user_id, + created_at: now, + expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now), + }; + self.db + .oidccode_authsession + .raw_put(&*code, Cbor(&session)); + code + } + + pub async fn exchange_auth_code( + &self, + code: &str, + client_id: &str, + redirect_uri: &str, + code_verifier: Option<&str>, + ) -> Result { + let session: AuthCodeSession = self + .db + .oidccode_authsession + .get(code) + .await + .deserialized::>() + .map(|cbor: Cbor| cbor.0) + .map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?; + self.db.oidccode_authsession.remove(code); + if SystemTime::now() > session.expires_at { + return Err!(Request(Forbidden("Authorization code has expired"))); + } + if session.client_id != client_id { + return Err!(Request(Forbidden("client_id mismatch"))); + } + if session.redirect_uri != redirect_uri { + return Err!(Request(Forbidden("redirect_uri mismatch"))); + } + + if let Some(challenge) = &session.code_challenge { + let Some(verifier) = code_verifier else { + return Err!(Request(Forbidden("code_verifier required for PKCE"))); + }; + Self::validate_code_verifier(verifier)?; + let method = session + .code_challenge_method + .as_deref() + .unwrap_or("S256"); + let computed = match method { + | "S256" => { + let hash = utils::hash::sha256::hash(verifier.as_bytes()); + b64.encode(hash) + }, + | "plain" => verifier.to_owned(), + | _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))), + }; + if computed != *challenge { + return Err!(Request(Forbidden("PKCE verification failed"))); + } + } + + Ok(session) + } + + /// Validate code_verifier per RFC 7636 Section 4.1: must be 43-128 + /// characters using only unreserved characters [A-Z] / [a-z] / [0-9] / + /// "-" / "." / "_" / "~". + fn validate_code_verifier(verifier: &str) -> Result { + if !(43..=128).contains(&verifier.len()) { + return Err!(Request(InvalidParam("code_verifier must be 43-128 characters"))); + } + if !verifier.bytes().all(|b| { + b.is_ascii_alphanumeric() || b == b'-' || b == b'.' || b == b'_' || b == b'~' + }) { + return Err!(Request(InvalidParam("code_verifier contains invalid characters"))); + } + Ok(()) + } + + pub fn sign_id_token(&self, claims: &IdTokenClaims) -> Result { + let mut header = jwt::Header::new(jwt::Algorithm::ES256); + header.kid = Some(self.key_id.clone()); + let key = jwt::EncodingKey::from_ec_der(&self.signing_key_der); + jwt::encode(&header, claims, &key) + .map_err(|e| err!(error!("Failed to sign ID token: {e}"))) + } + + #[must_use] + pub fn jwks(&self) -> serde_json::Value { + serde_json::json!({ + "keys": [self.jwk.clone()], + }) + } + + #[must_use] + pub fn at_hash(access_token: &str) -> String { + let hash = utils::hash::sha256::hash(access_token.as_bytes()); + b64.encode(&hash[..16]) + } + + #[must_use] + pub fn auth_request_lifetime() -> Duration { AUTH_REQUEST_LIFETIME } +} + +fn redirect_uri_matches(registered: &str, requested: &str) -> bool { + if registered == requested { + return true; + } + + match (url::Url::parse(registered), url::Url::parse(requested)) { + | (Ok(reg), Ok(req)) if is_loopback_redirect(®) && is_loopback_redirect(&req) => + reg.scheme() == req.scheme() + && reg.host_str() == req.host_str() + && reg.path() == req.path() + && reg.query() == req.query() + && reg.fragment() == req.fragment(), + | _ => false, + } +} + +fn is_loopback_redirect(uri: &url::Url) -> bool { + uri.scheme() == "http" + && matches!(uri.host_str().and_then(|h| h.parse::().ok()), Some(ip) if ip.is_loopback()) +}