Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 60 additions & 50 deletions crates/atuin-scripts/src/database.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{path::Path, str::FromStr, time::Duration};

use atuin_common::utils;
use sql_builder::quote;
use sqlx::{
Result, Row,
sqlite::{
Expand Down Expand Up @@ -141,38 +142,21 @@ impl Database {
row.get("tag")
}

#[allow(dead_code)]
async fn load(&self, id: &str) -> Result<Option<Script>> {
debug!("loading script item {}", id);
pub async fn list_with_where_clauses(
&self,
additional_where_clause: Option<&str>,
) -> Result<Vec<Script>> {
debug!("listing scripts");

let res = sqlx::query("select * from scripts where id = ?1")
.bind(id)
.map(Self::query_script)
.fetch_optional(&self.pool)
.await?;
let mut query_str = "select scripts.* from scripts".to_string();

// intentionally not joining, don't want to duplicate the script data in memory a whole bunch.
if let Some(mut script) = res {
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
.bind(id)
.map(Self::query_script_tags)
.fetch_all(&self.pool)
.await?;

script.tags = tags;
Ok(Some(script))
} else {
Ok(None)
if let Some(additional_clause) = additional_where_clause {
query_str.push_str(&format!(" {}", additional_clause));
}
}

pub async fn list(&self) -> Result<Vec<Script>> {
debug!("listing scripts");
let query = sqlx::query(&query_str);

let mut res = sqlx::query("select * from scripts")
.map(Self::query_script)
.fetch_all(&self.pool)
.await?;
let mut res = query.map(Self::query_script).fetch_all(&self.pool).await?;

// Fetch all the tags for each script
for script in res.iter_mut() {
Expand All @@ -188,6 +172,33 @@ impl Database {
Ok(res)
}

pub async fn list(&self) -> Result<Vec<Script>> {
self.list_with_where_clauses(None).await
}

#[allow(dead_code)]
async fn load(&self, id: &str) -> Result<Option<Script>> {
debug!("loading script item {}", id);

let query_str = format!("where id = {}", quote(id));
let scripts = self.list_with_where_clauses(Some(&query_str)).await?;
Ok(scripts.into_iter().next())
}

pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
let query_str = format!("where name = {}", quote(name));
let scripts = self.list_with_where_clauses(Some(&query_str)).await?;
Ok(scripts.into_iter().next())
}

pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Script>> {
let query_str = format!(
"join script_tags on scripts.id = script_tags.script_id where script_tags.tag = {}",
quote(tag)
);
self.list_with_where_clauses(Some(&query_str)).await
}

pub async fn delete(&self, id: &str) -> Result<()> {
debug!("deleting script {}", id);

Expand Down Expand Up @@ -242,29 +253,6 @@ impl Database {

Ok(())
}

pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
let res = sqlx::query("select * from scripts where name = ?1")
.bind(name)
.map(Self::query_script)
.fetch_optional(&self.pool)
.await?;

let script = if let Some(mut script) = res {
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
.bind(script.id.to_string())
.map(Self::query_script_tags)
.fetch_all(&self.pool)
.await?;

script.tags = tags;
Some(script)
} else {
None
};

Ok(script)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -355,4 +343,26 @@ mod test {
let loaded = db.list().await.unwrap();
assert_eq!(loaded.len(), 0);
}

#[tokio::test]
async fn test_get_by_tag() {
let db = Database::new("sqlite::memory:", 1.0).await.unwrap();

let script = Script::builder()
.name("test name".to_string())
.description("test description".to_string())
.shebang("test shebang".to_string())
.script("test script".to_string())
.tags(vec!["tag1".to_string(), "tag2".to_string()])
.build();

db.save(&script).await.unwrap();

let loaded = db.get_by_tag("tag1").await.unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded[0], script);

let loaded = db.get_by_tag("tag3").await.unwrap();
assert_eq!(loaded.len(), 0);
}
}
13 changes: 10 additions & 3 deletions crates/atuin/src/command/client/scripts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ pub struct Run {
}

#[derive(Parser, Debug)]
pub struct List {}
pub struct List {
#[arg(short, long)]
pub tag: Option<String>,
}

#[derive(Parser, Debug)]
pub struct Get {
Expand Down Expand Up @@ -356,10 +359,14 @@ impl Cmd {

async fn handle_list(
_settings: &Settings,
_list: List,
list: List,
script_db: atuin_scripts::database::Database,
) -> Result<()> {
let scripts = script_db.list().await?;
let scripts = if let Some(tag) = list.tag {
script_db.get_by_tag(&tag).await?
} else {
script_db.list().await?
};

if scripts.is_empty() {
println!("No scripts found");
Expand Down