diff --git a/src/api/client/session/mod.rs b/src/api/client/session/mod.rs index c4c55254c..f440629c1 100644 --- a/src/api/client/session/mod.rs +++ b/src/api/client/session/mod.rs @@ -29,7 +29,7 @@ use self::{ldap::ldap_login, password::password_login}; pub(crate) use self::{ logout::{logout_all_route, logout_route}, refresh::refresh_token_route, - sso::{sso_callback_route, sso_login_route, sso_login_with_provider_route}, + sso::{sso_callback_route, sso_fallback_route, sso_login_route, sso_login_with_provider_route}, token::login_token_route, }; use super::TOKEN_LENGTH; diff --git a/src/api/client/session/sso.rs b/src/api/client/session/sso.rs index 8021fde2f..9daea4362 100644 --- a/src/api/client/session/sso.rs +++ b/src/api/client/session/sso.rs @@ -9,6 +9,7 @@ use reqwest::header::{CONTENT_TYPE, HeaderValue}; use ruma::{ Mxc, OwnedMxcUri, OwnedRoomId, OwnedUserId, ServerName, UserId, api::client::session::{sso_callback, sso_login, sso_login_with_provider}, + api::client::uiaa::{AuthType, get_uiaa_fallback_page}, }; use serde::{Deserialize, Serialize}; use tuwunel_core::{ @@ -423,6 +424,46 @@ pub(crate) async fn sso_callback_route( .to_string() .into(); + if let Some(ref redirect_url) = session.redirect_url { + if redirect_url.scheme() == "uiaa" { + let uiaa_session_id = redirect_url.path(); + + // Find the UIAA session by its ID + let (db_user_id, device_id, mut uiaainfo) = services + .uiaa + .get_uiaa_session_by_session_id(uiaa_session_id) + .await + .ok_or_else(|| err!(Request(Forbidden("UIAA session not found."))))?; + + // SECURITY: Ensure the user authenticating via SSO is the owner of the UIAA session + if db_user_id != user_id { + return Err!(Request(Forbidden("UIAA session belongs to a different user."))); + } + + // Mark the SSO step as completed + if !uiaainfo.completed.contains(&AuthType::Sso) { + uiaainfo.completed.push(AuthType::Sso); + services.uiaa.update_uiaa_session( + &user_id, + &device_id, + uiaa_session_id, + Some(&uiaainfo), + ); + } + + // Redirect back to the fallback page to render the success HTML + let location = format!( + "/_matrix/client/v3/auth/m.login.sso/fallback/web?session={}", + uiaa_session_id + ); + + return Ok(sso_callback::unstable::Response { + location, + cookie: Some(cookie), + }); + } + } + // Determine the next provider to chain after this one. let next_idp_url = services .config @@ -727,3 +768,38 @@ fn parse_user_id(server_name: &ServerName, username: &str) -> Result, + body: Ruma, +) -> Result { + let session = &body.body.session; + + // Check if this UIAA session has already been completed via SSO + if let Some((_, _, uiaainfo)) = services.uiaa.get_uiaa_session_by_session_id(session).await { + if uiaainfo.completed.contains(&AuthType::Sso) { + let html = r#"Authentication Complete

Authentication Successful

You can safely close this window.

"#; + + return Ok(get_uiaa_fallback_page::v3::Response::html(html.as_bytes().to_vec())); + } + } + + // Session is not completed yet, show the prompt to continue + let url_str = format!("/_matrix/client/v3/login/sso/redirect?redirectUrl=uiaa:{}", session); + + let html = format!( + r#"Authentication Required
🛡️

Single Sign-On Required

To confirm this action, please re-authenticate with your Single Sign-On provider.

Continue with SSO
Security Notice: If you did not trigger this action, your account may be compromised.
"#, + url_str, + ); + + Ok(get_uiaa_fallback_page::v3::Response::html(html.into_bytes())) +} diff --git a/src/api/router.rs b/src/api/router.rs index fd29acab4..8455c886c 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -41,6 +41,7 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::sso_login_route) .ruma_route(&client::sso_login_with_provider_route) .ruma_route(&client::sso_callback_route) + .ruma_route(&client::sso_fallback_route) .ruma_route(&client::whoami_route) .ruma_route(&client::logout_route) .ruma_route(&client::logout_all_route) diff --git a/src/api/router/auth/uiaa.rs b/src/api/router/auth/uiaa.rs index 8d7272ddf..28da817b4 100644 --- a/src/api/router/auth/uiaa.rs +++ b/src/api/router/auth/uiaa.rs @@ -33,10 +33,9 @@ where .unwrap_or(false) .await || (cfg!(feature = "ldap") && services.config.ldap.enable); - //TODO: UIAA for SSO. + // Check if user has SSO authentication available let sso_flow = [AuthType::Sso]; - let has_sso = false; - let _has_sso = sender_user + let has_sso = sender_user .map_async(|sender_user| { services .oauth diff --git a/src/router/layers.rs b/src/router/layers.rs index 75ead23cd..6cea6231e 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -34,6 +34,16 @@ const TUWUNEL_CSP: &[&str; 5] = &[ "sandbox", ]; +const TUWUNEL_HTML_CSP: &[&str; 7] = &[ + "default-src 'none'", + "script-src 'unsafe-inline'", + "style-src 'unsafe-inline'", + "frame-ancestors 'none'", + "form-action 'none'", + "base-uri 'none'", + "sandbox", +]; + const TUWUNEL_PERMISSIONS_POLICY: &[&str; 2] = &["interest-cohort=()", "browsing-topics=()"]; pub(crate) fn build(services: &Arc) -> Result<(Router, Guard)> { @@ -95,7 +105,11 @@ pub(crate) fn build(services: &Arc) -> Result<(Router, Guard)> { )) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_str(&TUWUNEL_CSP.join(";"))?, + |res: &http::Response<_>| { + let is_html = res.headers().get(header::CONTENT_TYPE).map_or(false, |v| v.to_str().unwrap_or_default().contains("text/html")); + let csp = if is_html { TUWUNEL_HTML_CSP.join(";") } else { TUWUNEL_CSP.join(";") }; + HeaderValue::from_str(&csp).ok() + }, )) .layer(cors_layer(server)) .layer(body_limit_layer(server)) diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 51d866e6e..2af52907b 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -11,7 +11,7 @@ use ruma::{ }, }; use tuwunel_core::{ - Err, Result, debug_warn, err, error, extract, implement, + Err, Result, err, error, extract, implement, utils::{self, BoolExt, hash, string::EMPTY}, }; use tuwunel_database::{Deserialized, Json, Map}; @@ -180,8 +180,17 @@ pub async fn try_auth( return Ok((false, uiaainfo)); } }, - | AuthData::FallbackAcknowledgement(session) => { - debug_warn!("FallbackAcknowledgement: {session:?}"); + | AuthData::FallbackAcknowledgement(_session) => { + // FallbackAcknowledgement is used for SSO and other fallback flows. + // The SSO callback route marks the session as completed by adding AuthType::Sso. + if !uiaainfo.completed.contains(&AuthType::Sso) { + uiaainfo.auth_error = Some(StandardErrorBody { + kind: ErrorKind::forbidden(), + message: "SSO authentication not completed for this session.".to_owned(), + }); + + return Ok((false, uiaainfo)); + } }, | AuthData::Dummy(_) => { uiaainfo.completed.push(AuthType::Dummy); @@ -252,7 +261,7 @@ pub fn get_uiaa_request( } #[implement(Service)] -fn update_uiaa_session( +pub fn update_uiaa_session( &self, user_id: &UserId, device_id: &DeviceId, @@ -286,3 +295,26 @@ async fn get_uiaa_session( .deserialized() .map_err(|_| err!(Request(Forbidden("UIAA session does not exist.")))) } + +#[implement(Service)] +pub async fn get_uiaa_session_by_session_id( + &self, + session_id: &str, +) -> Option<(OwnedUserId, OwnedDeviceId, UiaaInfo)> { + use futures::{TryStreamExt, pin_mut}; + + // Iterate over keys only (fastest way without a secondary index) + let stream = self.db.userdevicesessionid_uiaainfo.keys::<(OwnedUserId, OwnedDeviceId, String)>(); + pin_mut!(stream); + + while let Ok(Some((user_id, device_id, session))) = stream.try_next().await { + if session == session_id { + // Found the key, now fetch the actual UiaaInfo + if let Ok(uiaainfo) = self.get_uiaa_session(&user_id, &device_id, session_id).await { + return Some((user_id, device_id, uiaainfo)); + } + } + } + + None +}