diff --git a/README.md b/README.md index f514bbe9..0f96c108 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,15 @@ There are tests trying out the OAuth2 flow which can be run with `cargo test`. You can also test the OAuth2 flow manually by running the flask application in `test_client/client.py`. +### Creating a migration + +To make a change to the database scheme, we use diesel migrations + +1. To create a migration, run `diesel migration generate ` +2. Fill in the generated `up.sql` and `down.sql` +3. Re-generate `src/models/schema.rs` by running `diesel print-schema` + > Caution: at the moment, the `users` schema cannot be generated correctly automatically. + ### Using Nix We have provided a [flake.nix](./flake.nix) for easy setup for Nix users. With [flakes enabled](https://nixos.wiki/wiki/Flakes), run `nix develop`. diff --git a/migrations/2025-01-23-180624_create_passkeys/up.sql b/migrations/2025-01-23-180624_create_passkeys/up.sql index 3ef0378a..4c257a6d 100644 --- a/migrations/2025-01-23-180624_create_passkeys/up.sql +++ b/migrations/2025-01-23-180624_create_passkeys/up.sql @@ -1,10 +1,10 @@ -- Your SQL goes here CREATE TABLE passkeys ( id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id), name VARCHAR(255) NOT NULL, cred VARCHAR NOT NULL, cred_id VARCHAR NOT NULL, - user_id INTEGER NOT NULL REFERENCES users(id), last_used TIMESTAMP NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT NOW() ); diff --git a/migrations/2025-07-01-174809_add_roles_table/down.sql b/migrations/2025-07-01-174809_add_roles_table/down.sql new file mode 100644 index 00000000..1b6a4420 --- /dev/null +++ b/migrations/2025-07-01-174809_add_roles_table/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +DROP TABLE users_roles; +DROP TABLE roles; diff --git a/migrations/2025-07-01-174809_add_roles_table/up.sql b/migrations/2025-07-01-174809_add_roles_table/up.sql new file mode 100644 index 00000000..e26e0f7b --- /dev/null +++ b/migrations/2025-07-01-174809_add_roles_table/up.sql @@ -0,0 +1,13 @@ +-- Your SQL goes here +CREATE TABLE roles ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL UNIQUE, + description VARCHAR(255) NOT NULL, + client_id INTEGER REFERENCES clients(id) +); + +CREATE TABLE users_roles ( + user_id INTEGER REFERENCES users(id) ON DELETE CASCADE, + role_id INTEGER REFERENCES roles(id) ON DELETE CASCADE, + PRIMARY KEY (user_id, role_id) +) diff --git a/src/controllers/mod.rs b/src/controllers/mod.rs index 5e5c3e3e..e86ae4c0 100644 --- a/src/controllers/mod.rs +++ b/src/controllers/mod.rs @@ -2,6 +2,7 @@ pub mod clients_controller; pub mod mailing_list_controller; pub mod oauth_controller; pub mod pages_controller; +pub mod roles_controller; pub mod sessions_controller; pub mod users_controller; pub mod webauthn_controller; diff --git a/src/controllers/oauth_controller.rs b/src/controllers/oauth_controller.rs index f51efa98..30e2e6df 100644 --- a/src/controllers/oauth_controller.rs +++ b/src/controllers/oauth_controller.rs @@ -18,6 +18,7 @@ use crate::jwt::JWTBuilder; use crate::models::client::*; use crate::models::session::*; use crate::models::user::*; +use crate::util::split_scopes; use crate::ephemeral::session::ensure_logged_in_and_redirect; use crate::errors::OAuthError::InvalidCookie; @@ -327,24 +328,32 @@ pub async fn token( ))) } else { let user = User::find(token.user_id, &db).await?; - let id_token = token - .scope - .as_ref() - .map(|scope| -> Option { - match scope.contains("openid") { - true => { - jwt_builder.encode_id_token(&client, &user, config).ok() - }, - false => None, - } - }) - .flatten(); + let scopes = split_scopes(&token.scope); + let id_token = if scopes.contains(&"openid".into()) { + let roles = if scopes.contains(&"roles".into()) { + Some( + user.clone() + .roles_for_client(client.id, &db) + .await? + .iter() + .map(|r| r.clone().name) + .collect(), + ) + } else { + None + }; + jwt_builder + .encode_id_token(&client, &user, config, roles) + .ok() + } else { + None + }; let session = Session::create_client_session( &user, &client, token.scope, - &config, + config, &db, ) .await?; diff --git a/src/controllers/roles_controller.rs b/src/controllers/roles_controller.rs new file mode 100644 index 00000000..0dfd5ffe --- /dev/null +++ b/src/controllers/roles_controller.rs @@ -0,0 +1,166 @@ +use diesel::result::DatabaseErrorKind; +use rocket::form::Form; +use rocket::http::Status; +use rocket::response::status::Custom; +use rocket::response::{Redirect, Responder, status}; +use rocket::serde::json::Json; +use std::fmt::Debug; + +use crate::DbConn; +use crate::ephemeral::from_api::Api; +use crate::ephemeral::session::AdminSession; +use crate::errors::{Either, InternalError, Result, ZauthError}; +use crate::models::client::Client; +use crate::models::role::{NewRole, Role}; +use crate::models::user::User; +use crate::views::accepter::Accepter; + +#[get("/roles?")] +pub async fn list_roles<'r>( + error: Option, + db: DbConn, + session: AdminSession, +) -> Result> { + let roles = Role::all(&db).await?; + let clients = Client::all(&db).await?; + + Ok(Accepter { + html: template! { + "roles/index.html"; + roles: Vec = roles.clone(), + clients: Vec = clients, + error: Option = error, + current_user: User = session.admin, + }, + json: Json(roles), + }) +} + +#[post("/roles", data = "")] +pub async fn create_role<'r, 'a>( + role: Api, + db: DbConn, + _admin: AdminSession, +) -> Result< + Either, impl Responder<'r, 'static> + use<'r>>, +> { + let role = Role::create(role.into_inner(), &db).await; + match role { + Ok(role) => Ok(Either::Left(Accepter { + html: Redirect::to(uri!(list_roles(None::))), + json: status::Created::new(String::from("/role")).body(Json(role)), + })), + Err(ZauthError::Internal(InternalError::DatabaseError( + diesel::result::Error::DatabaseError( + DatabaseErrorKind::UniqueViolation, + _, + ), + ))) => Ok(Either::Right(Accepter { + html: Redirect::to(uri!(list_roles(Some( + "role name already exists" + )))), + json: "role name already exists", + })), + Err(err) => Err(err), + } +} + +#[get("/roles/?&")] +pub async fn show_role_page<'r>( + id: i32, + error: Option, + info: Option, + session: AdminSession, + db: DbConn, +) -> Result> { + let role = Role::find(id, &db).await?; + let users = role.clone().users(&db).await?; + + let client = if let Some(id) = role.client_id { + Some(Client::find(id, &db).await?) + } else { + None + }; + + Ok(template! { "roles/show_role.html"; + current_user: User = session.admin, + role: Role = role, + client: Option = client, + users: Vec = users, + error: Option = error, + info: Option = info + }) +} + +#[delete("/roles/")] +pub async fn delete_role<'r>( + id: i32, + _session: AdminSession, + db: DbConn, +) -> Result> { + let role = Role::find(id, &db).await?; + role.delete(&db).await?; + Ok(Accepter { + html: Redirect::to(uri!(list_roles(None::))), + json: Custom(Status::NoContent, ()), + }) +} + +#[post("/roles//users", data = "")] +pub async fn add_user<'r>( + username: Form, + role_id: i32, + db: DbConn, + _session: AdminSession, +) -> Result> { + let role = Role::find(role_id, &db).await?; + let user_result = User::find_by_username(username.clone(), &db).await; + Ok(match user_result { + Ok(user) => { + role.add_user(user.id, &db).await?; + Accepter { + html: Redirect::to(uri!(show_role_page( + role.id, + None::, + Some("user added") + ))), + json: Custom(Status::Ok, ()), + } + }, + Err(ZauthError::NotFound(_)) => Accepter { + html: Redirect::to(uri!(show_role_page( + role.id, + Some("user not found"), + None:: + ))), + json: Custom(Status::NotFound, ()), + }, + _ => Accepter { + html: Redirect::to(uri!(show_role_page( + role.id, + Some("error occured"), + None:: + ))), + json: Custom(Status::InternalServerError, ()), + }, + }) +} + +#[delete("/roles//users/")] +pub async fn delete_user<'r>( + role_id: i32, + user_id: i32, + _session: AdminSession, + db: DbConn, +) -> Result> { + let role = Role::find(role_id, &db).await?; + role.remove_user(user_id, &db).await?; + Ok(Accepter { + html: Redirect::to(uri!(show_role_page( + role_id, + None::, + Some("user deleted") + ))), + json: Custom(Status::Ok, ()), + }) +} diff --git a/src/controllers/users_controller.rs b/src/controllers/users_controller.rs index 9525955a..16395644 100644 --- a/src/controllers/users_controller.rs +++ b/src/controllers/users_controller.rs @@ -14,7 +14,10 @@ use crate::ephemeral::session::{ use crate::errors::Either::{self, Left, Right}; use crate::errors::{InternalError, OneOf, Result, ZauthError}; use crate::mailer::Mailer; +use crate::models::client::Client; +use crate::models::role::Role; use crate::models::user::*; +use crate::util::split_scopes; use crate::views::accepter::Accepter; use crate::{DbConn, util}; use askama::Template; @@ -23,14 +26,78 @@ use rocket::State; use rocket::form::Form; use rocket::serde::json::Json; +#[derive(Serialize)] +pub struct UserInfo { + id: i32, + username: String, + admin: bool, + full_name: String, + #[serde(skip_serializing_if = "Option::is_none")] + roles: Option>, +} + +impl UserInfo { + async fn new( + user: User, + client: Option, + scope: Option, + db: &DbConn, + ) -> Result { + let scopes = split_scopes(&scope); + + let roles = if let Some(client) = client { + if scopes.contains(&"roles".into()) { + Some( + user.clone() + .roles_for_client(client.id, db) + .await? + .iter() + .map(|r| r.clone().name) + .collect(), + ) + } else { + None + } + } else { + Some( + user.clone() + .roles(db) + .await? + .iter() + .map(|r| r.clone().name) + .collect(), + ) + }; + + Ok(UserInfo { + id: user.id, + username: user.username, + admin: user.admin, + full_name: user.full_name, + roles, + }) + } +} + #[get("/current_user")] -pub fn current_user(session: ClientOrUserSession) -> Json { - Json(session.user) +pub async fn current_user( + session: ClientOrUserSession, + db: DbConn, +) -> Result> { + Ok(Json( + UserInfo::new(session.user, session.client, session.scope, &db).await?, + )) } #[get("/current_user")] -pub fn current_user_as_client(session: ClientSession) -> Json { - Json(session.user) +pub async fn current_user_as_client( + session: ClientSession, + db: DbConn, +) -> Result> { + Ok(Json( + UserInfo::new(session.user, Some(session.client), session.scope, &db) + .await?, + )) } #[get("/users/")] @@ -41,12 +108,20 @@ pub async fn show_user<'r>( ) -> Result> { // Cloning the username is necessary because it's used later let user = User::find_by_username(username.clone(), &db).await?; + let user_roles = user.clone().roles(&db).await?; + let roles = if session.user.admin { + Role::all(&db).await? + } else { + vec![] + }; // Check whether the current session is allowed to view this user if session.user.admin || session.user.username == username { Ok(Accepter { html: template!("users/show.html"; user: User = user.clone(), current_user: User = session.user, + user_roles: Vec = user_roles, + roles: Vec = roles, errors: Option = None ), json: Json(user), @@ -123,9 +198,7 @@ pub async fn create_user<'r>( db: DbConn, config: &State, ) -> Result + use<'r>> { - let user = User::create(user.into_inner(), config.bcrypt_cost, &db) - .await - .map_err(ZauthError::from)?; + let user = User::create(user.into_inner(), config.bcrypt_cost, &db).await?; // Cloning the username is necessary because it's used later Ok(Accepter { html: Redirect::to(uri!(show_user(user.username.clone()))), @@ -227,15 +300,20 @@ pub async fn update_user<'r, 'o: 'r>( json: Custom(Status::NoContent, ()), })) }, - Err(ZauthError::ValidationError(errors)) => Ok(OneOf::Two(Custom( - Status::UnprocessableEntity, - template! { - "users/show.html"; - user: User = user, - current_user: User = session.user, - errors: Option = Some(errors.clone()), - }, - ))), + Err(ZauthError::ValidationError(errors)) => { + let roles = user.clone().roles(&db).await?; + Ok(OneOf::Two(Custom( + Status::UnprocessableEntity, + template! { + "users/show.html"; + user: User = user, + current_user: User = session.user, + user_roles: Vec = roles, + roles: Vec = vec![], + errors: Option = Some(errors.clone()), + }, + ))) + }, Err(other) => Err(other), } } else { @@ -542,3 +620,35 @@ pub async fn confirm_email_post<'r>( )) } } + +#[post("/users//roles", data = "")] +pub async fn add_role<'r>( + username: String, + role_id: Form, + db: DbConn, + _session: AdminSession, +) -> Result> { + let role = Role::find(*role_id, &db).await?; + let user_result = User::find_by_username(username.clone(), &db).await?; + role.add_user(user_result.id, &db).await?; + Ok(Accepter { + html: Redirect::to(uri!(show_user(username))), + json: Custom(Status::NoContent, ()), + }) +} + +#[delete("/users//roles/")] +pub async fn delete_role<'r>( + role_id: i32, + username: String, + _session: AdminSession, + db: DbConn, +) -> Result> { + let role = Role::find(role_id, &db).await?; + let user_result = User::find_by_username(username.clone(), &db).await?; + role.remove_user(user_result.id, &db).await?; + Ok(Accepter { + html: Redirect::to(uri!(show_user(username))), + json: Custom(Status::NoContent, ()), + }) +} diff --git a/src/db_seed.rs b/src/db_seed.rs index 407759ac..94322c30 100644 --- a/src/db_seed.rs +++ b/src/db_seed.rs @@ -1,8 +1,8 @@ use crate::DbConn; use crate::errors::{Result, ZauthError}; -use crate::models::client::schema::clients; use crate::models::client::{Client, NewClient}; -use crate::models::user::schema::users; +use crate::models::schema::clients; +use crate::models::schema::users; use crate::models::user::{NewUser, User}; use crate::util::random_token; use diesel::RunQueryDsl; diff --git a/src/ephemeral/session.rs b/src/ephemeral/session.rs index c1455975..a861ad0c 100644 --- a/src/ephemeral/session.rs +++ b/src/ephemeral/session.rs @@ -157,6 +157,7 @@ impl<'r> FromRequest<'r> for AdminSession { pub struct ClientSession { pub user: User, pub client: Client, + pub scope: Option, } #[rocket::async_trait] @@ -197,9 +198,11 @@ impl<'r> FromRequest<'r> for ClientSession { match Session::find_by_key(key.to_string(), &db).await { Ok(session) => match session.user(&db).await { Ok(user) => match session.client(&db).await { - Ok(Some(client)) => { - Outcome::Success(ClientSession { user, client }) - }, + Ok(Some(client)) => Outcome::Success(ClientSession { + user, + client, + scope: session.scope, + }), _ => Outcome::Error(( Status::Unauthorized, "there is no client associated to this client session", @@ -222,6 +225,7 @@ impl<'r> FromRequest<'r> for ClientSession { pub struct ClientOrUserSession { pub user: User, pub client: Option, + pub scope: Option, } #[rocket::async_trait] @@ -236,6 +240,7 @@ impl<'r> FromRequest<'r> for ClientOrUserSession { Outcome::Success(ClientOrUserSession { user: session.user, client: None, + scope: None, }) }, _ => match request.guard::().await { @@ -243,6 +248,7 @@ impl<'r> FromRequest<'r> for ClientOrUserSession { Outcome::Success(ClientOrUserSession { user: session.user, client: Some(session.client), + scope: session.scope, }) }, _ => Outcome::Error(( diff --git a/src/jwt.rs b/src/jwt.rs index 85b94ada..41e9ce4a 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -30,6 +30,8 @@ struct IDToken { iat: i64, preferred_username: String, email: String, + #[serde(skip_serializing_if = "Option::is_none")] + roles: Option>, } impl JWTBuilder { @@ -97,6 +99,7 @@ impl JWTBuilder { client: &Client, user: &User, config: &Config, + roles: Option>, ) -> Result { let id_token = IDToken { sub: user.id.to_string(), @@ -106,6 +109,7 @@ impl JWTBuilder { exp: Utc::now().timestamp() + config.client_session_seconds, preferred_username: user.username.clone(), email: user.email.clone(), + roles, }; self.encode(&id_token) } diff --git a/src/lib.rs b/src/lib.rs index 0cfafa6d..fd959f69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,10 +147,18 @@ fn assemble(rocket: Rocket) -> Rocket { users_controller::confirm_email_post, users_controller::show_confirm_unsubscribe, users_controller::unsubscribe_user, + users_controller::add_role, + users_controller::delete_role, mailing_list_controller::list_mails, mailing_list_controller::send_mail, mailing_list_controller::show_create_mail_page, mailing_list_controller::show_mail, + roles_controller::list_roles, + roles_controller::create_role, + roles_controller::delete_role, + roles_controller::show_role_page, + roles_controller::add_user, + roles_controller::delete_user, ], ) .register( diff --git a/src/models/client.rs b/src/models/client.rs index 29435388..21314607 100644 --- a/src/models/client.rs +++ b/src/models/client.rs @@ -3,27 +3,15 @@ use diesel::{self, prelude::*}; use crate::DbConn; use crate::errors::{AuthenticationError, Result, ZauthError}; -use self::schema::clients; +use crate::models::schema::{clients, roles}; use crate::util::random_token; use chrono::NaiveDateTime; use validator::Validate; -const SECRET_LENGTH: usize = 64; +use super::role::Role; -pub mod schema { - table! { - clients { - id -> Integer, - name -> Text, - description -> Text, - secret -> Text, - needs_grant -> Bool, - redirect_uri_list -> Text, - created_at -> Timestamp, - } - } -} +const SECRET_LENGTH: usize = 64; #[derive(Serialize, AsChangeset, Queryable, Debug, Clone)] pub struct Client { @@ -176,4 +164,16 @@ impl Client { self.secret = Self::generate_random_secret(); self.update(db).await } + + pub async fn roles(&self, db: &DbConn) -> Result> { + let id = self.id; + db.run(move |conn| { + roles::table + .filter(roles::client_id.eq(id)) + .select(Role::as_select()) + .get_results(conn) + }) + .await + .map_err(ZauthError::from) + } } diff --git a/src/models/mail.rs b/src/models/mail.rs index 7041f754..48d62e17 100644 --- a/src/models/mail.rs +++ b/src/models/mail.rs @@ -1,6 +1,5 @@ use std::cmp::Reverse; -use self::schema::mails; use crate::DbConn; use crate::errors::{self, ZauthError}; use chrono::NaiveDateTime; @@ -10,19 +9,7 @@ use diesel::result::Error as DieselError; use rocket::serde::Serialize; use validator::Validate; -pub mod schema { - table! { - use diesel::sql_types::*; - - mails { - id -> Integer, - sent_on -> Timestamp, - subject -> Text, - body -> Text, - author -> Varchar, - } - } -} +use super::schema::mails; #[derive(Clone, Debug, Queryable, Serialize)] #[serde(crate = "rocket::serde")] diff --git a/src/models/mod.rs b/src/models/mod.rs index 2f6533a6..d22d05a5 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,5 +1,7 @@ pub mod client; pub mod mail; pub mod passkey; +pub mod role; +pub mod schema; pub mod session; pub mod user; diff --git a/src/models/passkey.rs b/src/models/passkey.rs index 9c14ec90..a4c75296 100644 --- a/src/models/passkey.rs +++ b/src/models/passkey.rs @@ -8,24 +8,7 @@ use crate::{ errors::{self, InternalError, Result, ZauthError}, }; -use self::schema::passkeys; - -pub mod schema { - table! { - use diesel::sql_types::*; - - passkeys { - id -> Integer, - user_id -> Integer, - name -> VarChar, - cred -> VarChar, - cred_id -> VarChar, - last_used -> Timestamp, - created_at -> Timestamp, - } - - } -} +use super::schema::passkeys; #[derive( Queryable, Selectable, PartialEq, Debug, Clone, Serialize, AsChangeset, @@ -36,9 +19,9 @@ pub struct PassKey { pub user_id: i32, pub name: String, #[serde(skip)] - cred: String, + pub cred: String, #[serde(skip)] - cred_id: String, + pub cred_id: String, pub last_used: NaiveDateTime, pub created_at: NaiveDateTime, } diff --git a/src/models/role.rs b/src/models/role.rs new file mode 100644 index 00000000..cf3427a1 --- /dev/null +++ b/src/models/role.rs @@ -0,0 +1,142 @@ +use crate::{ + DbConn, + errors::{Result, ZauthError}, +}; +use diesel::{self, prelude::*}; +use validator::Validate; + +use crate::models::schema::{roles, users, users_roles}; +use crate::models::user::User; + +#[derive( + Deserialize, + Serialize, + Queryable, + Debug, + Clone, + Identifiable, + PartialEq, + Selectable, +)] +pub struct Role { + pub id: i32, + pub name: String, + pub description: String, + pub client_id: Option, +} + +#[derive(Validate, FromForm, Debug, Insertable, Deserialize)] +#[diesel(table_name = roles)] +pub struct NewRole { + #[validate(length(min = 1, max = 30))] + pub name: String, + #[validate(length(min = 1, max = 100))] + pub description: String, + pub client_id: Option, +} + +#[derive( + Identifiable, Selectable, Queryable, Associations, Debug, Insertable, +)] +#[diesel(belongs_to(Role))] +#[diesel(belongs_to(User))] +#[diesel(table_name = users_roles)] +#[diesel(primary_key(role_id, user_id))] +pub struct UserRole { + pub role_id: i32, + pub user_id: i32, +} + +impl Role { + pub async fn create(role: NewRole, db: &DbConn) -> Result { + role.validate()?; + + db.run(move |conn| { + diesel::insert_into(roles::table) + .values(&role) + .get_result::(conn) + }) + .await + .map_err(ZauthError::from) + } + + pub async fn add_user(&self, user_id: i32, db: &DbConn) -> Result { + let id = self.id; + let user_role = db + .run(move |conn| { + users_roles::table + .filter(users_roles::user_id.eq(user_id)) + .filter(users_roles::role_id.eq(id)) + .first::(conn) + .optional() + }) + .await + .map_err(ZauthError::from)?; + + if user_role.is_none() { + // UserRole not already exists + let user_role = UserRole { + role_id: self.id, + user_id, + }; + db.run(move |conn| { + diesel::insert_into(users_roles::table) + .values(&user_role) + .execute(conn) + }) + .await + .map_err(ZauthError::from)?; + Ok(true) + } else { + Ok(false) + } + } + + pub async fn remove_user(self, user_id: i32, db: &DbConn) -> Result { + let count = db + .run(move |conn| { + diesel::delete( + users_roles::table + .filter(users_roles::user_id.eq(user_id)) + .filter(users_roles::role_id.eq(self.id)), + ) + .execute(conn) + }) + .await + .map_err(ZauthError::from)?; + Ok(count > 0) + } + + pub async fn users(self, db: &DbConn) -> Result> { + db.run(move |conn| { + UserRole::belonging_to(&self) + .inner_join(users::table) + .select(User::as_select()) + .load(conn) + }) + .await + .map_err(ZauthError::from) + } + + pub async fn find(id: i32, db: &DbConn) -> Result { + db.run(move |conn| diesel::QueryDsl::find(roles::table, id).first(conn)) + .await + .map_err(ZauthError::from) + } + + pub async fn all(db: &DbConn) -> Result> { + let all_roles = + db.run(move |conn| roles::table.load::(conn)).await?; + Ok(all_roles) + } + + pub async fn delete(self, db: &DbConn) -> Result<()> { + db.run(move |conn| { + diesel::delete(roles::table.filter(roles::id.eq(self.id))) + .execute(conn) + }) + .await + .map_err(ZauthError::from)?; + Ok(()) + } +} diff --git a/src/models/schema.rs b/src/models/schema.rs new file mode 100644 index 00000000..76fa7ff2 --- /dev/null +++ b/src/models/schema.rs @@ -0,0 +1,119 @@ +// @generated automatically by Diesel CLI. + +pub mod sql_types { + #[derive(diesel::sql_types::SqlType)] + #[diesel(postgres_type(name = "user_state"))] + pub struct UserState; +} + +diesel::table! { + clients (id) { + id -> Int4, + #[max_length = 255] + name -> Varchar, + description -> Text, + #[max_length = 255] + secret -> Varchar, + needs_grant -> Bool, + redirect_uri_list -> Text, + created_at -> Timestamp, + } +} + +diesel::table! { + mails (id) { + id -> Int4, + sent_on -> Timestamp, + subject -> Text, + body -> Text, + #[max_length = 255] + author -> Varchar, + } +} + +diesel::table! { + passkeys (id) { + id -> Integer, + user_id -> Integer, + #[max_length = 255] + name -> Varchar, + cred -> Varchar, + cred_id -> Varchar, + last_used -> Timestamp, + created_at -> Timestamp, + } +} + +diesel::table! { + roles (id) { + id -> Int4, + #[max_length = 255] + name -> Varchar, + #[max_length = 255] + description -> Varchar, + client_id -> Nullable, + } +} + +diesel::table! { + sessions (id) { + id -> Int4, + #[max_length = 255] + key -> Nullable, + user_id -> Int4, + client_id -> Nullable, + created_at -> Timestamp, + expires_at -> Timestamp, + valid -> Bool, + scope -> Nullable, + } +} + +diesel::table! { + use diesel::sql_types::*; + use crate::models::user::UserStateMapping; + + users { + id -> Int4, + username -> Varchar, + hashed_password -> Varchar, + admin -> Bool, + password_reset_token -> Nullable, + password_reset_expiry -> Nullable, + full_name -> Varchar, + email -> Varchar, + pending_email -> Nullable, + pending_email_token -> Nullable, + pending_email_expiry -> Nullable, + ssh_key -> Nullable, + state -> UserStateMapping, + last_login -> Timestamp, + created_at -> Timestamp, + subscribed_to_mailing_list -> Bool, + unsubscribe_token -> Varchar, + } +} + +diesel::table! { + users_roles (user_id, role_id) { + user_id -> Int4, + role_id -> Int4, + } +} + +diesel::joinable!(passkeys -> users (user_id)); +diesel::joinable!(roles -> clients (client_id)); +diesel::joinable!(sessions -> clients (client_id)); +diesel::joinable!(sessions -> users (user_id)); +diesel::joinable!(users_roles -> roles (role_id)); +diesel::joinable!(users_roles -> users (user_id)); + +diesel::allow_tables_to_appear_in_same_query!( + clients, + mails, + passkeys, + roles, + sessions, + users, + users_roles, +); diff --git a/src/models/user.rs b/src/models/user.rs index 07bb6757..b495a774 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -1,4 +1,4 @@ -use self::schema::users; +use super::schema::{roles, users}; use crate::DbConn; use crate::errors::{self, InternalError, LoginError, ZauthError}; use diesel::{self, prelude::*}; @@ -17,6 +17,8 @@ use rocket::{FromFormField, serde::Serialize}; use std::convert::TryFrom; use validator::{Validate, ValidationError, ValidationErrors}; +use super::role::{Role, UserRole}; + #[derive( DbEnum, Debug, Deserialize, FromFormField, Serialize, Clone, PartialEq, )] @@ -42,34 +44,17 @@ impl fmt::Display for UserState { } } -pub mod schema { - table! { - use diesel::sql_types::*; - use crate::models::user::UserStateMapping; - - users { - id -> Integer, - username -> Varchar, - hashed_password -> Varchar, - admin -> Bool, - password_reset_token -> Nullable, - password_reset_expiry -> Nullable, - full_name -> Varchar, - email -> Varchar, - pending_email -> Nullable, - pending_email_token -> Nullable, - pending_email_expiry -> Nullable, - ssh_key -> Nullable, - state -> UserStateMapping, - last_login -> Timestamp, - created_at -> Timestamp, - subscribed_to_mailing_list -> Bool, - unsubscribe_token -> Varchar, - } - } -} - -#[derive(Validate, Serialize, AsChangeset, Queryable, Debug, Clone)] +#[derive( + Validate, + Serialize, + AsChangeset, + Selectable, + Queryable, + Debug, + Clone, + PartialEq, + Identifiable, +)] #[diesel(table_name = users)] #[diesel(treat_none_as_null = true)] #[serde(crate = "rocket::serde")] @@ -541,6 +526,37 @@ impl User { Err(e) => Err(ZauthError::from(e)), } } + + pub async fn roles(self, db: &DbConn) -> Result, ZauthError> { + db.run(move |conn| { + UserRole::belonging_to(&self) + .inner_join(roles::table) + .select(Role::as_select()) + .load(conn) + }) + .await + .map_err(ZauthError::from) + } + + pub async fn roles_for_client( + self, + client_id: i32, + db: &DbConn, + ) -> Result, ZauthError> { + db.run(move |conn| { + UserRole::belonging_to(&self) + .inner_join(roles::table) + .filter( + roles::client_id + .eq(client_id) + .or(roles::client_id.is_null()), + ) + .select(Role::as_select()) + .load(conn) + }) + .await + .map_err(ZauthError::from) + } } fn hash( diff --git a/src/util.rs b/src/util.rs index a00bf86f..b096cb48 100644 --- a/src/util.rs +++ b/src/util.rs @@ -10,6 +10,12 @@ pub fn random_token(token_length: usize) -> String { .collect() } +pub fn split_scopes(scope: &Option) -> Vec { + scope + .as_ref() + .map_or(vec![], |scope| scope.split(" ").map(String::from).collect()) +} + // pub use dev::seed_database; // // mod dev { diff --git a/templates/base_logged_in.html b/templates/base_logged_in.html index 7fa2f31f..18e146ba 100644 --- a/templates/base_logged_in.html +++ b/templates/base_logged_in.html @@ -21,6 +21,7 @@ {% if current_user.admin %} Users Clients + Roles {% endif %} Logout diff --git a/templates/roles/index.html b/templates/roles/index.html new file mode 100644 index 00000000..e41fc2e4 --- /dev/null +++ b/templates/roles/index.html @@ -0,0 +1,134 @@ +{% extends "base_logged_in.html" %} + + +{% block content %} +
+ + +
+
+ +
+ Roles ({{ roles.len() }}) +
+ +
+ Roles are given to users, which allows for more fine-grained client-level permissions.
+ If a client requests the user info after login, it can give this user additional permissions based on the included roles.
+ Global roles are always returned; client-specific roles are only returned for that client.
+ The OAuth scope must include 'roles' for the roles to be included in the ID token or user info. +
+ + +
+ +
+ New role +
+
+ + + {% match error %} + {% when Some with (error) %} +
+ {{ error }} +
+ {% when None %} + {% endmatch %} + + + + + + + + + + + + + {% for role in roles %} + + + + + + + + + + + + {% endfor %} + + + {% if roles.len() == 0 %} + + + + {% endif %} + +
NameDescription
+ {{ role.name }} + {{ role.description }} +
+ + +
+
No roles configured
+
+ + + +{% endblock content %} diff --git a/templates/roles/show_role.html b/templates/roles/show_role.html new file mode 100644 index 00000000..4108a51f --- /dev/null +++ b/templates/roles/show_role.html @@ -0,0 +1,116 @@ +{% extends "base_logged_in.html" %} + +{% block content %} +
+ +
+
+ +
+ Users with role {{role.name}} +
+ + +
+ {% match client %} + {% when Some with (client) %} + for client {{client.name}}. + {% when None %} + Global role. + {% endmatch %} +
+ +
+ Add user +
+ +
+
+ + + {% match error %} + {% when Some with (error) %} +
+ {{ error }} +
+ {% when None %} + {% endmatch %} + + + {% match info %} + {% when Some with (info) %} +
+ {{ info }} +
+ {% when None %} + {% endmatch %} + + + + + + + + + + + + {% for user in users %} + + + + + + + + + {% endfor %} + + + {% if users.len() == 0 %} + + + + {% endif %} + +
Name
+ {{ user.username }} + +
+ + +
+
No users mapped yet
+
+ + + +{% endblock content %} diff --git a/templates/users/show.html b/templates/users/show.html index 95bb3aba..6e9c067d 100644 --- a/templates/users/show.html +++ b/templates/users/show.html @@ -86,7 +86,7 @@ {% if current_user.admin %} -
+
@@ -101,6 +101,64 @@
{% endif %} + +
+ +
+ {% for role in user_roles %} + {% if let Some(_) = role.client_id %} + + {% else %} + + {% endif %} + {% if current_user.admin %} + {{ role.name }} + + + {% else -%} + {{ role.name }} + {% endif %} + + {% endfor %} +
+ + {% if current_user.admin %} +
+
+ +
+ + +
+ + {% endif %} +
+
{% endblock content %} diff --git a/tests/clients.rs b/tests/clients.rs index 33119608..5ae90d59 100644 --- a/tests/clients.rs +++ b/tests/clients.rs @@ -18,7 +18,7 @@ async fn create_and_update_client() { common::as_admin(async move |http_client: HttpClient, db, _user| { let client_name = "test"; - let client_form = format!("name={}", url(&client_name),); + let client_form = format!("name={}", url(client_name),); let response = http_client .post("/clients") diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 98aec352..457a59e1 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -58,7 +58,7 @@ pub fn config() -> Config { async fn reset_db(db: &DbConn) { db.run(|conn| { - sql_query("TRUNCATE TABLE mails, sessions, users, clients, passkeys") + sql_query("TRUNCATE TABLE mails, sessions, users, clients, passkeys, users_roles, roles") .execute(conn) .expect("drop all tables"); }) diff --git a/tests/oauth.rs b/tests/oauth.rs index 8d8e9a48..56c64a5b 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -20,6 +20,8 @@ use rocket::http::{Accept, ContentType}; use zauth::DbConn; use zauth::controllers::oauth_controller::UserToken; use zauth::models::client::{Client, NewClient}; +use zauth::models::role::NewRole; +use zauth::models::role::Role; use zauth::models::user::{NewUser, User}; use zauth::token_store::TokenStore; @@ -57,21 +59,29 @@ async fn create_user(db: &DbConn) -> User { .expect("user") } -async fn create_client(db: &DbConn) -> Client { - let mut client = Client::create( - NewClient { - name: String::from(CLIENT_ID), - }, - &db, - ) - .await - .expect("client created"); +async fn create_client(db: &DbConn, name: &str) -> Client { + let mut client = Client::create(NewClient { name: name.into() }, &db) + .await + .expect("client created"); client.needs_grant = true; client.redirect_uri_list = String::from(REDIRECT_URI); client.update(db).await.expect("client updated") } +async fn create_role(db: &DbConn, name: &str, client_id: Option) -> Role { + Role::create( + NewRole { + name: name.into(), + description: "test".into(), + client_id, + }, + db, + ) + .await + .expect("role created") +} + // Test all the usual oauth requests until `access_token/id_token` is retrieved. async fn get_token( authorize_url: String, @@ -209,7 +219,7 @@ async fn get_token( async fn normal_flow() { common::as_visitor(async move |http_client, db| { let user = create_user(&db).await; - let client = create_client(&db).await; + let client = create_client(&db, CLIENT_ID).await; // 1. User is redirected to OAuth server with request params given by // the client @@ -300,6 +310,7 @@ async fn normal_flow() { assert!(data["id"].is_number()); assert_eq!(data["username"], USER_USERNAME); + assert_eq!(data["roles"], Value::Null); }) .await; } @@ -308,7 +319,7 @@ async fn normal_flow() { async fn openid_flow() { common::as_visitor(async move |http_client, db| { let user = create_user(&db).await; - let client = create_client(&db).await; + let client = create_client(&db, CLIENT_ID).await; let authorize_url = format!( "/oauth/authorize?response_type=code&redirect_uri={}&client_id={}&\ @@ -347,6 +358,102 @@ async fn openid_flow() { .claims; assert_eq!(id_token["preferred_username"], USER_USERNAME); assert_eq!(id_token["email"], USER_EMAIL); + assert_eq!(id_token["roles"], Value::Null); + }) + .await; +} + +#[rocket::async_test] +async fn roles_flow() { + common::as_visitor(async move |http_client, db| { + let user = create_user(&db).await; + let client = create_client(&db, CLIENT_ID).await; + let client_not_used = create_client(&db, "not_used").await; + let role_global = create_role(&db, "global", None).await; + let role_client = create_role(&db, "client", Some(client.id)).await; + let role_client_not_used = + create_role(&db, "client_not_used", Some(client_not_used.id)).await; + let _role_global_not_mapped = + create_role(&db, "global_not_mapped", None).await; + let _role_client_not_mapped = + create_role(&db, "client_not_mapped", Some(client.id)).await; + + role_global + .add_user(user.id, &db) + .await + .expect("add user to global role"); + role_client + .add_user(user.id, &db) + .await + .expect("add user to user role"); + role_client_not_used + .add_user(user.id, &db) + .await + .expect("add user to user role"); + + let authorize_url = format!( + "/oauth/authorize?response_type=code&redirect_uri={}&client_id={}&\ + state={}&scope={}", + url(REDIRECT_URI), + url(CLIENT_ID), + url(CLIENT_STATE), + url("openid roles") + ); + + let data = get_token(authorize_url, &http_client, &client, &user).await; + + let url = "/oauth/jwks"; + let req = http_client.get(url); + let response = req.dispatch().await; + let response_body = + response.into_string().await.expect("response body"); + let jwk_set: JwkSet = + serde_json::from_str(&response_body).expect("response json values"); + + let mut validation = Validation::new(jsonwebtoken::Algorithm::ES384); + validation.set_audience(&[CLIENT_ID]); + validation.set_issuer(&["http://localhost:8000"]); + + let id_token = jsonwebtoken::decode::( + data["id_token"].as_str().unwrap(), + &DecodingKey::from_jwk(&jwk_set.keys.get(0).unwrap()).unwrap(), + &validation, + ) + .expect("id token") + .claims; + + let roles: Vec = + serde_json::from_value(id_token["roles"].clone()) + .expect("roles in id token"); + + assert_eq!(roles.len(), 2); + assert!( + (roles[0] == "global" && roles[1] == "client") + || (roles[0] == "client" && roles[1] == "global") + ); + + let token = data["access_token"].as_str().expect("access token"); + + let response = http_client + .get("/current_user") + .header(Accept::JSON) + .header(Header::new("Authorization", format!("Bearer {}", token))) + .dispatch() + .await; + + let response_body = + response.into_string().await.expect("response body"); + let data: Value = + serde_json::from_str(&response_body).expect("response json values"); + + let roles: Vec = serde_json::from_value(data["roles"].clone()) + .expect("roles in id token"); + + assert_eq!(roles.len(), 2); + assert!( + (roles[0] == "global" && roles[1] == "client") + || (roles[0] == "client" && roles[1] == "global") + ); }) .await; } diff --git a/tests/roles.rs b/tests/roles.rs new file mode 100644 index 00000000..aaea051b --- /dev/null +++ b/tests/roles.rs @@ -0,0 +1,319 @@ +use common::{HttpClient, url}; +use rocket::http::{Accept, ContentType, Status}; +use zauth::models::{ + client::{Client, NewClient}, + role::{NewRole, Role}, + user::User, +}; + +mod common; + +#[rocket::async_test] +async fn list_roles_as_user() { + common::as_user(async move |http_client: HttpClient, _db, _user: User| { + let response = http_client.get("/roles").dispatch().await; + + assert_eq!(response.status(), Status::Forbidden); + }) + .await; +} + +#[rocket::async_test] +async fn list_roles_as_admin() { + common::as_admin(async move |http_client: HttpClient, _db, _user: User| { + let response = http_client.get("/roles").dispatch().await; + + assert_eq!(response.status(), Status::Ok); + }) + .await; +} + +#[rocket::async_test] +async fn create_role_as_user() { + common::as_user(async move |http_client: HttpClient, _db, _user: User| { + let role_name = "test"; + let role_form = + format!("name={role_name}&description=test_description"); + let response = http_client + .post("/roles") + .body(role_form) + .header(ContentType::Form) + .header(Accept::JSON) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Forbidden); + }) + .await; +} + +#[rocket::async_test] +async fn create_global_role() { + common::as_admin(async move |http_client: HttpClient, db, _user| { + let role_name = "test"; + let role_form = + format!("name={role_name}&description=test_description"); + + let response = http_client + .post("/roles") + .body(role_form) + .header(ContentType::Form) + .header(Accept::JSON) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Created); + + let json: Role = response.into_json().await.unwrap(); + + let created = Role::find(json.id, &db).await.unwrap(); + + assert_eq!(created.name, role_name); + assert_eq!(created.description, "test_description"); + assert_eq!(created.client_id, None); + }) + .await; +} + +#[rocket::async_test] +async fn create_client_role() { + common::as_admin(async move |http_client: HttpClient, db, _user| { + let client = Client::create( + NewClient { + name: String::from("test"), + }, + &db, + ) + .await + .unwrap(); + + let role_name = "test"; + let role_form = format!( + "name={role_name}&description=test_description&client_id={}", + client.id + ); + + let response = http_client + .post("/roles") + .body(role_form) + .header(ContentType::Form) + .header(Accept::JSON) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Created); + + let json: Role = response.into_json().await.unwrap(); + + let created = Role::find(json.id, &db).await.unwrap(); + + assert_eq!(created.name, role_name); + assert_eq!(created.description, "test_description"); + assert_eq!(created.client_id, Some(client.id)); + }) + .await; +} + +#[rocket::async_test] +async fn show_role_as_user() { + common::as_user(async move |http_client: HttpClient, db, _user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let response = http_client + .get(format!("/roles/{}", role.id)) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Forbidden); + }) + .await; +} + +#[rocket::async_test] +async fn show_role_as_admin() { + common::as_admin(async move |http_client: HttpClient, db, _user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let response = http_client + .get(format!("/roles/{}", role.id)) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + }) + .await; +} +#[rocket::async_test] +async fn delete_role() { + common::as_admin(async move |http_client: HttpClient, db, _user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let response = http_client + .delete(format!("/roles/{}", role.id)) + .dispatch() + .await; + + assert_eq!(response.status(), Status::SeeOther); + + assert!(Role::find(role.id, &db).await.is_err()); + }) + .await; +} + +#[rocket::async_test] +async fn add_user_to_role_as_user() { + common::as_user(async move |http_client: HttpClient, db, user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let role_form = format!("username={}", url(&user.username)); + + let response = http_client + .post(format!("/roles/{}/users", role.id)) + .body(role_form) + .header(ContentType::Form) + .header(Accept::JSON) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Forbidden); + }) + .await; +} + +#[rocket::async_test] +async fn add_user_to_role_as_admin() { + common::as_admin(async move |http_client: HttpClient, db, user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let role_form = format!("username={}", url(&user.username)); + + let response = http_client + .post(format!("/roles/{}/users", role.id)) + .body(role_form) + .header(ContentType::Form) + .dispatch() + .await; + + assert_eq!(response.status(), Status::SeeOther); + + let users = role.clone().users(&db).await.unwrap(); + assert_eq!(users.len(), 1); + assert_eq!(users[0].id, user.id); + + let response = http_client + .delete(format!("/roles/{}/users/{}", role.id, user.id)) + .dispatch() + .await; + assert_eq!(response.status(), Status::SeeOther); + + let users = role.clone().users(&db).await.unwrap(); + assert_eq!(users.len(), 0); + }) + .await; +} + +#[rocket::async_test] +async fn add_role_to_user_as_user() { + common::as_user(async move |http_client: HttpClient, db, user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let role_form = format!("role_id={}", role.id); + + let response = http_client + .post(format!("/users/{}/roles", user.username)) + .body(role_form) + .header(ContentType::Form) + .header(Accept::JSON) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Forbidden); + }) + .await; +} + +#[rocket::async_test] +async fn add_role_to_user_as_admin() { + common::as_admin(async move |http_client: HttpClient, db, user: User| { + let role = Role::create( + NewRole { + name: "test".into(), + description: "test".into(), + client_id: None, + }, + &db, + ) + .await + .unwrap(); + + let role_form = format!("role_id={}", role.id); + + let response = http_client + .post(format!("/users/{}/roles", user.username)) + .body(role_form) + .header(ContentType::Form) + .dispatch() + .await; + + assert_eq!(response.status(), Status::SeeOther); + + let users = role.clone().users(&db).await.unwrap(); + assert_eq!(users.len(), 1); + assert_eq!(users[0].id, user.id); + }) + .await; +}