Skip to content

Commit ac27f57

Browse files
authored
support loading text prompts (#12)
1 parent 7d750c6 commit ac27f57

File tree

6 files changed

+147
-13
lines changed

6 files changed

+147
-13
lines changed

langchain_mcp_adapters/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from contextlib import AsyncExitStack
22
from types import TracebackType
3-
from typing import Literal, TypedDict, cast
3+
from typing import Any, Literal, Optional, TypedDict, cast
44

5+
from langchain_core.messages import AIMessage, HumanMessage
56
from langchain_core.tools import BaseTool
67
from mcp import ClientSession, StdioServerParameters
78
from mcp.client.sse import sse_client
89
from mcp.client.stdio import stdio_client
910

11+
from langchain_mcp_adapters.prompts import load_mcp_prompt
1012
from langchain_mcp_adapters.tools import load_mcp_tools
1113

1214
DEFAULT_ENCODING = "utf-8"
@@ -211,6 +213,13 @@ def get_tools(self) -> list[BaseTool]:
211213
all_tools.extend(server_tools)
212214
return all_tools
213215

216+
async def get_prompt(
217+
self, server_name: str, prompt_name: str, arguments: Optional[dict[str, Any]]
218+
) -> list[HumanMessage | AIMessage]:
219+
"""Get a prompt from a given MCP server."""
220+
session = self.sessions[server_name]
221+
return await load_mcp_prompt(session, prompt_name, arguments)
222+
214223
async def __aenter__(self) -> "MultiServerMCPClient":
215224
try:
216225
connections = self.connections or {}

langchain_mcp_adapters/prompts.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from typing import Any, Optional
2+
3+
from langchain_core.messages import AIMessage, HumanMessage
4+
from mcp import ClientSession
5+
from mcp.types import PromptMessage
6+
7+
8+
def convert_mcp_prompt_message_to_langchain_message(
    message: PromptMessage,
) -> HumanMessage | AIMessage:
    """Convert an MCP prompt message to a LangChain message.

    Only text content is supported: ``user`` messages become
    ``HumanMessage`` and ``assistant`` messages become ``AIMessage``.

    Args:
        message: MCP prompt message to convert

    Returns:
        a LangChain message

    Raises:
        ValueError: if the content type or the role is unsupported
    """
    # Guard clause: reject anything that is not plain text content.
    if message.content.type != "text":
        raise ValueError(f"Unsupported prompt message content type: {message.content.type}")

    if message.role == "user":
        return HumanMessage(content=message.content.text)
    if message.role == "assistant":
        return AIMessage(content=message.content.text)

    raise ValueError(f"Unsupported prompt message role: {message.role}")
28+
29+
30+
async def load_mcp_prompt(
    session: ClientSession, name: str, arguments: Optional[dict[str, Any]] = None
) -> list[HumanMessage | AIMessage]:
    """Load MCP prompt and convert to LangChain messages.

    Args:
        session: active MCP client session to fetch the prompt from
        name: name of the prompt to load
        arguments: optional arguments for the prompt template

    Returns:
        the prompt's messages converted to LangChain messages, in order
    """
    response = await session.get_prompt(name, arguments)
    # Translate every MCP message into its LangChain counterpart.
    return list(map(convert_mcp_prompt_message_to_langchain_message, response.messages))

pyproject.toml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "langchain-mcp-adapters"
77
version = "0.0.4"
88
description = "Make Anthropic Model Context Protocol (MCP) tools compatible with LangChain and LangGraph agents."
99
authors = [
10-
{name = "Vadym Barda", email = "[email protected] "}
10+
{ name = "Vadym Barda", email = "[email protected]" },
1111
]
1212
readme = "README.md"
1313
requires-python = ">=3.10"
@@ -22,15 +22,14 @@ test = [
2222
"ruff>=0.9.4",
2323
"mypy>=1.8.0",
2424
"pytest-socket>=0.7.0",
25+
"pytest-asyncio>=0.25.0",
2526
"types-setuptools>=69.0.0",
2627
]
2728

2829
[tool.pytest.ini_options]
2930
minversion = "8.0"
3031
addopts = "-ra -q -v"
31-
testpaths = [
32-
"tests",
33-
]
32+
testpaths = ["tests"]
3433
python_files = ["test_*.py"]
3534
python_functions = ["test_*"]
3635

@@ -40,14 +39,14 @@ target-version = "py310"
4039

4140
[tool.ruff.lint]
4241
select = [
43-
"E", # pycodestyle errors
44-
"W", # pycodestyle warnings
45-
"F", # pyflakes
46-
"I", # isort
47-
"B", # flake8-bugbear
42+
"E", # pycodestyle errors
43+
"W", # pycodestyle warnings
44+
"F", # pyflakes
45+
"I", # isort
46+
"B", # flake8-bugbear
4847
]
4948
ignore = [
50-
"E501" # line-length
49+
"E501", # line-length
5150
]
5251

5352

@@ -57,4 +56,3 @@ warn_return_any = true
5756
warn_unused_configs = true
5857
disallow_untyped_defs = true
5958
check_untyped_defs = true
60-

tests/test_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
def test_import() -> None:
    """Test that every adapter module can be imported."""
    import langchain_mcp_adapters.client  # noqa: F401
    import langchain_mcp_adapters.prompts  # noqa: F401
    import langchain_mcp_adapters.tools  # noqa: F401

tests/test_prompts.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from unittest.mock import AsyncMock
2+
3+
import pytest
4+
from langchain_core.messages import AIMessage, HumanMessage
5+
from mcp.types import (
6+
EmbeddedResource,
7+
ImageContent,
8+
PromptMessage,
9+
TextContent,
10+
TextResourceContents,
11+
)
12+
13+
from langchain_mcp_adapters.prompts import (
14+
convert_mcp_prompt_message_to_langchain_message,
15+
load_mcp_prompt,
16+
)
17+
18+
19+
@pytest.mark.parametrize(
    "role,text,expected_cls",
    [
        ("assistant", "Hello", AIMessage),
        ("user", "Hello", HumanMessage),
    ],
)
def test_convert_mcp_prompt_message_to_langchain_message_with_text_content(
    role: str, text: str, expected_cls: type
):
    """Text content maps user -> HumanMessage and assistant -> AIMessage."""
    mcp_message = PromptMessage(role=role, content=TextContent(type="text", text=text))
    converted = convert_mcp_prompt_message_to_langchain_message(mcp_message)
    assert isinstance(converted, expected_cls)
    assert converted.content == text
33+
34+
35+
@pytest.mark.parametrize("role", ["assistant", "user"])
def test_convert_mcp_prompt_message_to_langchain_message_with_resource_content(role: str):
    """Embedded-resource content is rejected regardless of role."""
    greeting_resource = TextResourceContents(
        uri="message://greeting", mimeType="text/plain", text="hi"
    )
    mcp_message = PromptMessage(
        role=role,
        content=EmbeddedResource(type="resource", resource=greeting_resource),
    )
    with pytest.raises(ValueError):
        convert_mcp_prompt_message_to_langchain_message(mcp_message)
48+
49+
50+
@pytest.mark.parametrize("role", ["assistant", "user"])
def test_convert_mcp_prompt_message_to_langchain_message_with_image_content(role: str):
    """Image content is rejected regardless of role."""
    image = ImageContent(type="image", mimeType="image/png", data="base64data")
    with pytest.raises(ValueError):
        convert_mcp_prompt_message_to_langchain_message(
            PromptMessage(role=role, content=image)
        )
57+
58+
59+
@pytest.mark.asyncio
async def test_load_mcp_prompt():
    """Messages returned by the session are converted to LangChain messages in order."""
    mcp_messages = [
        PromptMessage(role="user", content=TextContent(type="text", text="Hello")),
        PromptMessage(role="assistant", content=TextContent(type="text", text="Hi")),
    ]
    session = AsyncMock()
    session.get_prompt = AsyncMock(return_value=AsyncMock(messages=mcp_messages))

    loaded = await load_mcp_prompt(session, "test_prompt")

    assert len(loaded) == 2
    assert isinstance(loaded[0], HumanMessage)
    assert loaded[0].content == "Hello"
    assert isinstance(loaded[1], AIMessage)
    assert loaded[1].content == "Hi"

uv.lock

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)