Skip to content

Commit 175681d

Browse files
committed
Add new sub-command query to query SQLite databases
1 parent 7236a94 commit 175681d

6 files changed

Lines changed: 251 additions & 1 deletion

File tree

Cargo.lock

Lines changed: 48 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ predicates = "3.1.0"
2727
pdf-extract = "0.7.4"
2828
base64 = "0.22.1"
2929
whoami = "1.6.0"
30+
rusqlite = { version = "0.32.1", features = ["bundled"] }

readme.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ Commands:
173173
wl Use Wolfram Language and Mathematica development as the prompt context
174174
zig Use Zig development as the prompt context
175175
jq Use jq development as the prompt context
176+
   
177+
🗄️ DATABASE
178+
query Query a SQLite database using natural language
176179
help Print this message or the help of the given subcommand(s)
177180
178181
Arguments:

src/lib.rs

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,169 @@ pub async fn create_commits(
15121512
Ok(())
15131513
}
15141514

1515+
pub async fn query_database(
1516+
opts: &ExecOptions,
1517+
database_path: &str,
1518+
prompt_text: &str,
1519+
) -> Result<(), Box<dyn Error + Send + Sync>> {
1520+
use rusqlite::Connection;
1521+
1522+
// Open the database
1523+
let conn = Connection::open(database_path).map_err(|e| {
1524+
format!("Failed to open database '{}': {}", database_path, e)
1525+
})?;
1526+
1527+
// Get the database schema
1528+
let mut schema_parts: Vec<String> = Vec::new();
1529+
1530+
// Get all table names
1531+
let mut stmt = conn
1532+
.prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
1533+
.map_err(|e| format!("Failed to query schema: {}", e))?;
1534+
1535+
let table_names: Vec<String> = stmt
1536+
.query_map([], |row| row.get(0))
1537+
.map_err(|e| format!("Failed to get table names: {}", e))?
1538+
.filter_map(|r| r.ok())
1539+
.collect();
1540+
1541+
// Get CREATE TABLE statement for each table
1542+
for table_name in &table_names {
1543+
let sql: String = conn
1544+
.query_row(
1545+
"SELECT sql FROM sqlite_master WHERE type='table' AND name=?1",
1546+
[table_name],
1547+
|row| row.get(0),
1548+
)
1549+
.map_err(|e| {
1550+
format!("Failed to get schema for table '{}': {}", table_name, e)
1551+
})?;
1552+
schema_parts.push(sql);
1553+
}
1554+
1555+
let schema = schema_parts.join("\n\n");
1556+
1557+
// Build the prompt for the LLM to generate SQL
1558+
let sql_prompt = format!(
1559+
"You are a SQLite expert. Given the following database schema, \
1560+
generate a SQL query to answer the user's question.\n\n\
1561+
IMPORTANT: Respond with ONLY the SQL query, no explanations, \
1562+
no markdown code blocks, no comments. Just the raw SQL.\n\n\
1563+
Schema:\n{schema}\n\n\
1564+
Question: {prompt_text}"
1565+
);
1566+
1567+
// Use OpenAI to generate the SQL query
1568+
let model = Model::Model(Provider::OpenAI, "gpt-4.1".to_string());
1569+
let secrets_path_str = get_secrets_path_str();
1570+
let full_config = get_full_config(&secrets_path_str)?;
1571+
let (_used_model, http_req) =
1572+
get_http_req(&Some(&model), &secrets_path_str, &full_config)?;
1573+
1574+
let mut raw_opts = opts.clone();
1575+
raw_opts.is_raw = true;
1576+
1577+
let req_body_obj = get_req_body_obj(&raw_opts, &http_req, &sql_prompt);
1578+
let resp = exec_request(&http_req, &req_body_obj).await?;
1579+
1580+
if !resp.status().is_success() {
1581+
let resp_json = resp.json::<Value>().await?;
1582+
let resp_formatted = serde_json::to_string_pretty(&resp_json).unwrap();
1583+
return Err(format!("Failed to generate SQL: {}", resp_formatted).into());
1584+
}
1585+
1586+
let ai_response = resp.json::<AiResponse>().await?;
1587+
let generated_sql = ai_response.choices[0]
1588+
.message
1589+
.content
1590+
.trim()
1591+
.trim_start_matches("```sql")
1592+
.trim_start_matches("```")
1593+
.trim_end_matches("```")
1594+
.trim()
1595+
.to_string();
1596+
1597+
if !opts.is_raw {
1598+
cprintln!("<bold>Generated SQL:</bold>");
1599+
println!("{}\n", generated_sql);
1600+
}
1601+
1602+
// Execute the generated SQL query
1603+
let mut stmt = conn
1604+
.prepare(&generated_sql)
1605+
.map_err(|e| format!("SQL error: {}", e))?;
1606+
1607+
let column_count = stmt.column_count();
1608+
let column_names: Vec<String> =
1609+
stmt.column_names().iter().map(|s| s.to_string()).collect();
1610+
1611+
// Execute and collect results
1612+
let rows_result = stmt.query_map([], |row| {
1613+
let mut row_values: Vec<String> = Vec::new();
1614+
for i in 0..column_count {
1615+
let value: rusqlite::types::Value = row.get(i)?;
1616+
let str_value = match value {
1617+
rusqlite::types::Value::Null => "NULL".to_string(),
1618+
rusqlite::types::Value::Integer(i) => i.to_string(),
1619+
rusqlite::types::Value::Real(f) => f.to_string(),
1620+
rusqlite::types::Value::Text(s) => s,
1621+
rusqlite::types::Value::Blob(b) => format!("<blob {} bytes>", b.len()),
1622+
};
1623+
row_values.push(str_value);
1624+
}
1625+
Ok(row_values)
1626+
});
1627+
1628+
let rows: Vec<Vec<String>> = rows_result
1629+
.map_err(|e| format!("Query execution error: {}", e))?
1630+
.filter_map(|r| r.ok())
1631+
.collect();
1632+
1633+
if !opts.is_raw {
1634+
cprintln!("<bold>Results ({} rows):</bold>", rows.len());
1635+
}
1636+
1637+
// Calculate column widths for pretty printing
1638+
let mut col_widths: Vec<usize> =
1639+
column_names.iter().map(|n| n.len()).collect();
1640+
for row in &rows {
1641+
for (i, val) in row.iter().enumerate() {
1642+
if val.len() > col_widths[i] {
1643+
col_widths[i] = val.len();
1644+
}
1645+
}
1646+
}
1647+
1648+
// Print header
1649+
let header: Vec<String> = column_names
1650+
.iter()
1651+
.enumerate()
1652+
.map(|(i, name)| format!("{:width$}", name, width = col_widths[i]))
1653+
.collect();
1654+
println!("{}", header.join(" | "));
1655+
1656+
// Print separator
1657+
let separator: Vec<String> =
1658+
col_widths.iter().map(|w| "-".repeat(*w)).collect();
1659+
println!("{}", separator.join("-+-"));
1660+
1661+
// Print rows
1662+
for row in &rows {
1663+
let formatted: Vec<String> = row
1664+
.iter()
1665+
.enumerate()
1666+
.map(|(i, val)| format!("{:width$}", val, width = col_widths[i]))
1667+
.collect();
1668+
println!("{}", formatted.join(" | "));
1669+
}
1670+
1671+
if !opts.is_raw {
1672+
println!();
1673+
}
1674+
1675+
Ok(())
1676+
}
1677+
15151678
#[cfg(test)]
15161679
mod tests {
15171680
use super::*;

src/main.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,19 @@ async fn exec_with_args(args: Args, stdin: &str) {
945945
Commands::Jq { prompt } => {
946946
prompt_with_lang_cntxt(&opts, &cmd, prompt).await
947947
}
948+
949+
//////////////////////////////////////////////////////////////////////////
950+
//============================== DATABASE ================================
951+
//////////////////////////////////////////////////////////////////////////
952+
Commands::SectionDatabase {} => {}
953+
Commands::Query { database, prompt } => {
954+
if let Err(err) =
955+
cai::query_database(&opts, database, &prompt.join(" ")).await
956+
{
957+
eprintln!("Error querying database: {err}");
958+
std::process::exit(1);
959+
}
960+
}
948961
},
949962
};
950963
}

src/types.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,25 @@ for all supported model ids):"
544544
/// The prompt to send to the AI model
545545
prompt: Vec<String>,
546546
},
547+
548+
#[clap(
549+
about = color_print::cformat!(
550+
"\n<u><em><b!>{:<60}</b!></em></u>", "🗄️ DATABASE"
551+
),
552+
verbatim_doc_comment,
553+
name = "\u{00A0}\u{00A0}\u{00A0}" // Non-breaking space placeholder
554+
)]
555+
SectionDatabase {},
556+
557+
/// Query a SQLite database using natural language
558+
Query {
559+
/// Path to the SQLite database file
560+
#[clap(required = true)]
561+
database: String,
562+
/// The natural language query/question about the data
563+
#[clap(required = true)]
564+
prompt: Vec<String>,
565+
},
547566
}
548567

549568
impl std::fmt::Display for Commands {
@@ -644,6 +663,10 @@ impl Commands {
644663
Commands::Wl { .. } => Some("Wolfram Language"),
645664
Commands::Zig { .. } => Some("Zig"),
646665
Commands::Jq { .. } => Some("JQ"),
666+
667+
// Database
668+
Commands::SectionDatabase { .. } => None,
669+
Commands::Query { .. } => Some("Query"),
647670
}
648671
.map(|s| s.to_string())
649672
}

0 commit comments

Comments
 (0)