Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions marimo/_server/ai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SchemaTable,
VariableContext,
)
from marimo._types.ids import SessionId

FIM_PREFIX_TAG = "<|fim_prefix|>"
FIM_SUFFIX_TAG = "<|fim_suffix|>"
Expand Down Expand Up @@ -234,15 +235,24 @@ def _get_mode_intro_message(mode: CopilotMode) -> str:
)


def _get_session_info(session_id: SessionId) -> str:
return (
f"Current notebook session ID: {session_id}. "
"Use this session_id with tools that require it."
)


def get_chat_system_prompt(
*,
custom_rules: Optional[str],
context: Optional[AiCompletionContext],
include_other_code: str,
mode: CopilotMode,
session_id: SessionId,
) -> str:
system_prompt: str = f"""
{_get_mode_intro_message(mode)}
{_get_session_info(session_id)}

Your goal is to do one of the following two things:

Expand Down
2 changes: 2 additions & 0 deletions marimo/_server/api/endpoints/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async def ai_chat(
"""
app_state = AppState(request)
app_state.require_current_session()
session_id = app_state.require_current_session_id()
config = app_state.app_config_manager.get_config(hide_secrets=False)
body = await parse_request(
request, cls=ChatRequest, allow_unknown_keys=True
Expand All @@ -223,6 +224,7 @@ async def ai_chat(
context=body.context,
include_other_code=body.include_other_code,
mode=ai_config.get("mode", "manual"),
session_id=session_id,
)

max_tokens = get_max_tokens(config)
Expand Down
7 changes: 7 additions & 0 deletions tests/_server/ai/snapshots/chat_system_prompts.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down Expand Up @@ -155,6 +156,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down Expand Up @@ -301,6 +303,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down Expand Up @@ -447,6 +450,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down Expand Up @@ -599,6 +603,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down Expand Up @@ -759,6 +764,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down Expand Up @@ -907,6 +913,7 @@ Your primary function is to help users create, analyze, and improve data science
- You do NOT have access to any external tools, plugins, or APIs.
- You may not perform any actions beyond generating text and code suggestions.

Current notebook session ID: s_test. Use this session_id with tools that require it.

Your goal is to do one of the following two things:

Expand Down
13 changes: 12 additions & 1 deletion tests/_server/ai/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SchemaTable,
VariableContext,
)
from marimo._types.ids import SessionId
from tests.mocks import snapshotter

snapshot = snapshotter(__file__)
Expand Down Expand Up @@ -216,7 +217,11 @@ def test_chat_system_prompts():
result: str = ""
result += _header("no custom rules")
result += get_chat_system_prompt(
custom_rules=None, include_other_code="", context=None, mode="manual"
custom_rules=None,
include_other_code="",
context=None,
mode="manual",
session_id=SessionId("s_test"), # stable fake session id for snapshot
)

result += _header("with custom rules")
Expand All @@ -225,6 +230,7 @@ def test_chat_system_prompts():
include_other_code="",
context=None,
mode="manual",
session_id=SessionId("s_test"),
)

result += _header("with variables")
Expand All @@ -235,6 +241,7 @@ def test_chat_system_prompts():
variables=["var1", "var2"],
),
mode="manual",
session_id=SessionId("s_test"),
)

result += _header("with VariableContext objects")
Expand All @@ -256,6 +263,7 @@ def test_chat_system_prompts():
]
),
mode="manual",
session_id=SessionId("s_test"),
)

result += _header("with context")
Expand Down Expand Up @@ -291,6 +299,7 @@ def test_chat_system_prompts():
],
),
mode="manual",
session_id=SessionId("s_test"),
)

result += _header("with other code")
Expand All @@ -299,6 +308,7 @@ def test_chat_system_prompts():
include_other_code="import pandas as pd\nimport numpy as np\n",
context=None,
mode="manual",
session_id=SessionId("s_test"),
)

result += _header("kitchen sink")
Expand All @@ -324,6 +334,7 @@ def test_chat_system_prompts():
],
),
mode="manual",
session_id=SessionId("s_test"),
)

snapshot("chat_system_prompts.txt", result)
Expand Down
Loading