Skip to content

Commit 2b7cdfe

Browse files
committed
feat: 'atuin script list' add tag filtering
- atuin script list -t <tag> - refactor database queries
1 parent 29576ac commit 2b7cdfe

File tree

2 files changed

+71
-52
lines changed

2 files changed

+71
-52
lines changed

crates/atuin-scripts/src/database.rs

Lines changed: 61 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::{path::Path, str::FromStr, time::Duration};
22

33
use atuin_common::utils;
4+
use sql_builder::quote;
45
use sqlx::{
56
Result, Row,
67
sqlite::{
@@ -141,38 +142,21 @@ impl Database {
141142
row.get("tag")
142143
}
143144

144-
#[allow(dead_code)]
145-
async fn load(&self, id: &str) -> Result<Option<Script>> {
146-
debug!("loading script item {}", id);
145+
pub async fn list_with_where_clauses(
146+
&self,
147+
additional_where_clause: Option<&str>,
148+
) -> Result<Vec<Script>> {
149+
debug!("listing scripts");
147150

148-
let res = sqlx::query("select * from scripts where id = ?1")
149-
.bind(id)
150-
.map(Self::query_script)
151-
.fetch_optional(&self.pool)
152-
.await?;
151+
let mut query_str = "select scripts.* from scripts".to_string();
153152

154-
// intentionally not joining, don't want to duplicate the script data in memory a whole bunch.
155-
if let Some(mut script) = res {
156-
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
157-
.bind(id)
158-
.map(Self::query_script_tags)
159-
.fetch_all(&self.pool)
160-
.await?;
161-
162-
script.tags = tags;
163-
Ok(Some(script))
164-
} else {
165-
Ok(None)
153+
if let Some(additional_clause) = additional_where_clause {
154+
query_str.push_str(&format!(" {}", additional_clause));
166155
}
167-
}
168156

169-
pub async fn list(&self) -> Result<Vec<Script>> {
170-
debug!("listing scripts");
157+
let query = sqlx::query(&query_str);
171158

172-
let mut res = sqlx::query("select * from scripts")
173-
.map(Self::query_script)
174-
.fetch_all(&self.pool)
175-
.await?;
159+
let mut res = query.map(Self::query_script).fetch_all(&self.pool).await?;
176160

177161
// Fetch all the tags for each script
178162
for script in res.iter_mut() {
@@ -188,6 +172,34 @@ impl Database {
188172
Ok(res)
189173
}
190174

175+
pub async fn list(&self) -> Result<Vec<Script>> {
176+
self.list_with_where_clauses(None).await
177+
}
178+
179+
#[allow(dead_code)]
180+
async fn load(&self, id: &str) -> Result<Option<Script>> {
181+
debug!("loading script item {}", id);
182+
183+
let query_str = format!("where id = {}", quote(id));
184+
let scripts = self.list_with_where_clauses(Some(&query_str)).await?;
185+
Ok(scripts.into_iter().next())
186+
}
187+
188+
pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
189+
let query_str = format!("where name = {}", quote(name));
190+
let scripts = self.list_with_where_clauses(Some(&query_str)).await?;
191+
Ok(scripts.into_iter().next())
192+
}
193+
194+
pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Script>> {
195+
let query_str = format!(
196+
"join script_tags on scripts.id = script_tags.script_id where script_tags.tag = {}",
197+
quote(tag)
198+
);
199+
self.list_with_where_clauses(Some(&query_str)).await
200+
}
201+
202+
191203
pub async fn delete(&self, id: &str) -> Result<()> {
192204
debug!("deleting script {}", id);
193205

@@ -243,28 +255,6 @@ impl Database {
243255
Ok(())
244256
}
245257

246-
pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> {
247-
let res = sqlx::query("select * from scripts where name = ?1")
248-
.bind(name)
249-
.map(Self::query_script)
250-
.fetch_optional(&self.pool)
251-
.await?;
252-
253-
let script = if let Some(mut script) = res {
254-
let tags = sqlx::query("select tag from script_tags where script_id = ?1")
255-
.bind(script.id.to_string())
256-
.map(Self::query_script_tags)
257-
.fetch_all(&self.pool)
258-
.await?;
259-
260-
script.tags = tags;
261-
Some(script)
262-
} else {
263-
None
264-
};
265-
266-
Ok(script)
267-
}
268258
}
269259

270260
#[cfg(test)]
@@ -355,4 +345,26 @@ mod test {
355345
let loaded = db.list().await.unwrap();
356346
assert_eq!(loaded.len(), 0);
357347
}
348+
349+
#[tokio::test]
350+
async fn test_get_by_tag() {
351+
let db = Database::new("sqlite::memory:", 1.0).await.unwrap();
352+
353+
let script = Script::builder()
354+
.name("test name".to_string())
355+
.description("test description".to_string())
356+
.shebang("test shebang".to_string())
357+
.script("test script".to_string())
358+
.tags(vec!["tag1".to_string(), "tag2".to_string()])
359+
.build();
360+
361+
db.save(&script).await.unwrap();
362+
363+
let loaded = db.get_by_tag("tag1").await.unwrap();
364+
assert_eq!(loaded.len(), 1);
365+
assert_eq!(loaded[0], script);
366+
367+
let loaded = db.get_by_tag("tag3").await.unwrap();
368+
assert_eq!(loaded.len(), 0);
369+
}
358370
}

crates/atuin/src/command/client/scripts.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ pub struct Run {
5454
}
5555

5656
#[derive(Parser, Debug)]
57-
pub struct List {}
57+
pub struct List {
58+
#[arg(short, long)]
59+
pub tag: Option<String>,
60+
}
5861

5962
#[derive(Parser, Debug)]
6063
pub struct Get {
@@ -356,10 +359,14 @@ impl Cmd {
356359

357360
async fn handle_list(
358361
_settings: &Settings,
359-
_list: List,
362+
list: List,
360363
script_db: atuin_scripts::database::Database,
361364
) -> Result<()> {
362-
let scripts = script_db.list().await?;
365+
let scripts = if let Some(tag) = list.tag {
366+
script_db.get_by_tag(&tag).await?
367+
} else {
368+
script_db.list().await?
369+
};
363370

364371
if scripts.is_empty() {
365372
println!("No scripts found");

0 commit comments

Comments
 (0)