Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions marimo/_server/ai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
FIM_SUFFIX_TAG = "<|fim_suffix|>"
FIM_MIDDLE_TAG = "<|fim_middle|>"

language_rules = {
LANGUAGES: list[Language] = ["python", "sql", "markdown"]
language_rules: dict[Language, list[str]] = {
"python": [
"For matplotlib: use plt.gca() as the last expression instead of plt.show().",
"For plotly: return the figure object directly.",
Expand All @@ -33,6 +34,15 @@
}


# Per-language rule overrides used when `support_multiple_cells` is set: the
# prompt builder looks a language up here first and falls back to
# `language_rules` only when there is no entry. Only "sql" is overridden.
language_rules_multiple_cells: dict[Language, list[str]] = {
"sql": [
# In this mode SQL is written as a Python cell that calls mo.sql(...), so the
# rules describe the call shape rather than bare SQL.
'SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.',
"This will automatically display the result in the UI. You do not need to return the dataframe in the cell.",
"The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax.",
]
}


def _format_schema_info(tables: Optional[list[SchemaTable]]) -> str:
"""Helper to format schema information from context"""
if not tables:
Expand Down Expand Up @@ -166,10 +176,13 @@ def get_refactor_or_insert_notebook_cell_system_prompt(

if support_multiple_cells:
# Add all language rules for multi-cell scenarios
for lang in language_rules:
if len(language_rules[lang]) > 0:
for lang in LANGUAGES:
language_rule = language_rules_multiple_cells.get(
lang, language_rules.get(lang, [])
)
if language_rule:
system_prompt += (
f"\n\n## Rules for {lang}:\n{_rules(language_rules[lang])}"
f"\n\n## Rules for {lang}:\n{_rules(language_rule)}"
)
elif language in language_rules and language_rules[language]:
system_prompt += (
Expand Down
4 changes: 3 additions & 1 deletion tests/_server/ai/snapshots/system_prompts.txt
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ Separate logic into multiple cells to keep the code organized and readable.
7. If a variable is already defined, use another name, or make it private by adding an underscore at the beginning.

## Rules for sql:
1. The SQL must use duckdb syntax.
1. SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.
2. This will automatically display the result in the UI. You do not need to return the dataframe in the cell.
3. The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax.

## Available variables from other cells:
- variable: `df`
Expand Down
Loading