11use std:: { path:: Path , str:: FromStr , time:: Duration } ;
22
33use atuin_common:: utils;
4+ use sql_builder:: quote;
45use sqlx:: {
56 Result , Row ,
67 sqlite:: {
@@ -141,38 +142,21 @@ impl Database {
141142 row. get ( "tag" )
142143 }
143144
144- #[ allow( dead_code) ]
145- async fn load ( & self , id : & str ) -> Result < Option < Script > > {
146- debug ! ( "loading script item {}" , id) ;
145+ pub async fn list_with_where_clauses (
146+ & self ,
147+ additional_where_clause : Option < & str > ,
148+ ) -> Result < Vec < Script > > {
149+ debug ! ( "listing scripts" ) ;
147150
148- let res = sqlx:: query ( "select * from scripts where id = ?1" )
149- . bind ( id)
150- . map ( Self :: query_script)
151- . fetch_optional ( & self . pool )
152- . await ?;
151+ let mut query_str = "select scripts.* from scripts" . to_string ( ) ;
153152
154- // intentionally not joining, don't want to duplicate the script data in memory a whole bunch.
155- if let Some ( mut script) = res {
156- let tags = sqlx:: query ( "select tag from script_tags where script_id = ?1" )
157- . bind ( id)
158- . map ( Self :: query_script_tags)
159- . fetch_all ( & self . pool )
160- . await ?;
161-
162- script. tags = tags;
163- Ok ( Some ( script) )
164- } else {
165- Ok ( None )
153+ if let Some ( additional_clause) = additional_where_clause {
154+ query_str. push_str ( & format ! ( " {}" , additional_clause) ) ;
166155 }
167- }
168156
169- pub async fn list ( & self ) -> Result < Vec < Script > > {
170- debug ! ( "listing scripts" ) ;
157+ let query = sqlx:: query ( & query_str) ;
171158
172- let mut res = sqlx:: query ( "select * from scripts" )
173- . map ( Self :: query_script)
174- . fetch_all ( & self . pool )
175- . await ?;
159+ let mut res = query. map ( Self :: query_script) . fetch_all ( & self . pool ) . await ?;
176160
177161 // Fetch all the tags for each script
178162 for script in res. iter_mut ( ) {
@@ -188,6 +172,33 @@ impl Database {
188172 Ok ( res)
189173 }
190174
175+ pub async fn list ( & self ) -> Result < Vec < Script > > {
176+ self . list_with_where_clauses ( None ) . await
177+ }
178+
179+ #[ allow( dead_code) ]
180+ async fn load ( & self , id : & str ) -> Result < Option < Script > > {
181+ debug ! ( "loading script item {}" , id) ;
182+
183+ let query_str = format ! ( "where id = {}" , quote( id) ) ;
184+ let scripts = self . list_with_where_clauses ( Some ( & query_str) ) . await ?;
185+ Ok ( scripts. into_iter ( ) . next ( ) )
186+ }
187+
188+ pub async fn get_by_name ( & self , name : & str ) -> Result < Option < Script > > {
189+ let query_str = format ! ( "where name = {}" , quote( name) ) ;
190+ let scripts = self . list_with_where_clauses ( Some ( & query_str) ) . await ?;
191+ Ok ( scripts. into_iter ( ) . next ( ) )
192+ }
193+
194+ pub async fn get_by_tag ( & self , tag : & str ) -> Result < Vec < Script > > {
195+ let query_str = format ! (
196+ "join script_tags on scripts.id = script_tags.script_id where script_tags.tag = {}" ,
197+ quote( tag)
198+ ) ;
199+ self . list_with_where_clauses ( Some ( & query_str) ) . await
200+ }
201+
191202 pub async fn delete ( & self , id : & str ) -> Result < ( ) > {
192203 debug ! ( "deleting script {}" , id) ;
193204
@@ -242,29 +253,6 @@ impl Database {
242253
243254 Ok ( ( ) )
244255 }
245-
246- pub async fn get_by_name ( & self , name : & str ) -> Result < Option < Script > > {
247- let res = sqlx:: query ( "select * from scripts where name = ?1" )
248- . bind ( name)
249- . map ( Self :: query_script)
250- . fetch_optional ( & self . pool )
251- . await ?;
252-
253- let script = if let Some ( mut script) = res {
254- let tags = sqlx:: query ( "select tag from script_tags where script_id = ?1" )
255- . bind ( script. id . to_string ( ) )
256- . map ( Self :: query_script_tags)
257- . fetch_all ( & self . pool )
258- . await ?;
259-
260- script. tags = tags;
261- Some ( script)
262- } else {
263- None
264- } ;
265-
266- Ok ( script)
267- }
268256}
269257
270258#[ cfg( test) ]
@@ -355,4 +343,26 @@ mod test {
355343 let loaded = db. list ( ) . await . unwrap ( ) ;
356344 assert_eq ! ( loaded. len( ) , 0 ) ;
357345 }
346+
347+ #[ tokio:: test]
348+ async fn test_get_by_tag ( ) {
349+ let db = Database :: new ( "sqlite::memory:" , 1.0 ) . await . unwrap ( ) ;
350+
351+ let script = Script :: builder ( )
352+ . name ( "test name" . to_string ( ) )
353+ . description ( "test description" . to_string ( ) )
354+ . shebang ( "test shebang" . to_string ( ) )
355+ . script ( "test script" . to_string ( ) )
356+ . tags ( vec ! [ "tag1" . to_string( ) , "tag2" . to_string( ) ] )
357+ . build ( ) ;
358+
359+ db. save ( & script) . await . unwrap ( ) ;
360+
361+ let loaded = db. get_by_tag ( "tag1" ) . await . unwrap ( ) ;
362+ assert_eq ! ( loaded. len( ) , 1 ) ;
363+ assert_eq ! ( loaded[ 0 ] , script) ;
364+
365+ let loaded = db. get_by_tag ( "tag3" ) . await . unwrap ( ) ;
366+ assert_eq ! ( loaded. len( ) , 0 ) ;
367+ }
358368}
0 commit comments