diff --git a/Cargo.lock b/Cargo.lock index f0b671a..72289fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -297,6 +297,8 @@ dependencies = [ "chrono", "etke_openai_api_rust", "matrix-sdk", + "mime", + "mime_guess", "mxidwc", "mxlink", "quick_cache", diff --git a/Cargo.toml b/Cargo.toml index 1e9b35e..bbdf54d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ tokio = { version = "1.48.*", features = ["rt", "rt-multi-thread", "macros"] } tracing = "0.1.*" tracing-subscriber = { version = "0.3.*", features = ["env-filter"] } url = "2.5.*" +mime_guess = "2.0.5" +mime = "0.3.17" [profile.release] strip = true diff --git a/etc/app/config.yml.dist b/etc/app/config.yml.dist index 2785301..177bfc9 100644 --- a/etc/app/config.yml.dist +++ b/etc/app/config.yml.dist @@ -11,6 +11,11 @@ user: # Leave empty to use the default (baibot). name: baibot + avatar: + # An optional path to an image file to be used as a custom avatar image. + # Leave empty to use the default. + source: + encryption: # An optional passphrase to use for backing up and recovering the bot's encryption keys. # You can use any string here. diff --git a/src/bot/implementation.rs b/src/bot/implementation.rs index c55a3a3..0dff359 100644 --- a/src/bot/implementation.rs +++ b/src/bot/implementation.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use std::fs; use std::{future::Future, pin::Pin}; +use mime::Mime; use mxlink::matrix_sdk::Room; use mxlink::matrix_sdk::media::{MediaFormat, MediaRequestParameters}; @@ -316,6 +318,19 @@ impl Bot { } } + let logo_bytes: Vec = if self.inner.config.user.avatar.is_none() { + LOGO_BYTES.to_vec() + } else { + let avatar = self.inner.config.user.avatar.as_ref().unwrap(); + let avatar_path = &avatar.source; + match fs::read(avatar_path) { + Ok(bytes) => bytes, + Err(_e) => { + LOGO_BYTES.to_vec() + } + } + }; + let should_update_avatar = match ¤t_avatar_url { Some(avatar_url) => { let request = MediaRequestParameters { @@ -328,20 +343,24 @@ impl Bot { .await .map_err(|e| anyhow::anyhow!("Failed fetching existing avatar: {:?}", e))?; - content.as_slice() != LOGO_BYTES + content.as_slice() != logo_bytes } None => true, }; if should_update_avatar { - tracing::info!("Updating avatar.."); + let mime_type: Mime = if self.inner.config.user.avatar.is_none() { + LOGO_MIME_TYPE.parse::().expect("Failed parsing mime type for logo") + } else { + let avatar = self.inner.config.user.avatar.as_ref().unwrap(); + let avatar_path = &avatar.source; + mime_guess::guess_mime_type(avatar_path) + }; - let mime_type = LOGO_MIME_TYPE - .parse() - .expect("Failed parsing mime type for logo"); + tracing::info!("Updating avatar.."); account - .upload_avatar(&mime_type, LOGO_BYTES.to_vec()) + .upload_avatar(&mime_type, logo_bytes) .await .map_err(|e| anyhow::anyhow!("Failed uploading avatar: {:?}", e))?; } diff --git a/src/bot/load_config.rs b/src/bot/load_config.rs index d56f6aa..6d9931d 100644 --- a/src/bot/load_config.rs +++ b/src/bot/load_config.rs @@ -4,6 +4,7 @@ use std::path::PathBuf; use anyhow::anyhow; use crate::agent::AgentPurpose; +use crate::entity::cfg::config::ConfigAvatar; pub use crate::entity::cfg::{Config, defaults as cfg_defaults, env as cfg_env}; @@ -37,6 +38,13 @@ pub fn load() -> anyhow::Result { config.user.encryption.recovery_reset_allowed = value.parse::()?; } cfg_env::BAIBOT_USER_NAME => config.user.name = value, + cfg_env::BAIBOT_USER_AVATAR_SOURCE => { + if let Some(ref mut avatar) = config.user.avatar { + avatar.source = value; + } else { + config.user.avatar = Some(ConfigAvatar { source: value }); + } + } cfg_env::BAIBOT_COMMAND_PREFIX => config.command_prefix = value, cfg_env::BAIBOT_ROOM_POST_JOIN_SELF_INTRODUCTION_ENABLED => { config.room.post_join_self_introduction_enabled = value.parse::()?; diff --git a/src/entity/cfg/config.rs b/src/entity/cfg/config.rs index 875e3ac..eb8674b 100644 --- a/src/entity/cfg/config.rs +++ b/src/entity/cfg/config.rs @@ -83,6 +83,12 @@ impl ConfigHomeserver { } } +#[derive(Debug, Serialize, Deserialize)] +pub struct ConfigAvatar { + #[serde(default)] + pub source: String, +} + #[derive(Debug, Serialize, Deserialize)] pub struct ConfigUser { pub mxid_localpart: String, @@ -93,6 +99,9 @@ pub struct ConfigUser { #[serde(default)] pub encryption: ConfigUserEncryption, + + #[serde(default)] + pub avatar: Option, } impl ConfigUser { diff --git a/src/entity/cfg/env.rs b/src/entity/cfg/env.rs index cfba7b6..5fcb809 100644 --- a/src/entity/cfg/env.rs +++ b/src/entity/cfg/env.rs @@ -6,6 +6,7 @@ pub const BAIBOT_HOMESERVER_URL: &str = "BAIBOT_HOMESERVER_URL"; pub const BAIBOT_USER_MXID_LOCALPART: &str = "BAIBOT_USER_MXID_LOCALPART"; pub const BAIBOT_USER_PASSWORD: &str = "BAIBOT_USER_PASSWORD"; pub const BAIBOT_USER_NAME: &str = "BAIBOT_USER_NAME"; +pub const BAIBOT_USER_AVATAR_SOURCE: &str = "BAIBOT_USER_AVATAR_SOURCE"; pub const BAIBOT_USER_ENCRYPTION_RECOVERY_PASSPHRASE: &str = "BAIBOT_USER_ENCRYPTION_RECOVERY_PASSPHRASE"; pub const BAIBOT_USER_ENCRYPTION_RECOVERY_RESET_ALLOWED: &str = diff --git a/src/entity/cfg/mod.rs b/src/entity/cfg/mod.rs index c4917b6..bb9e857 100644 --- a/src/entity/cfg/mod.rs +++ b/src/entity/cfg/mod.rs @@ -1,4 +1,4 @@ -mod config; +pub mod config; pub mod defaults; pub mod env;