Commit 8a464c5

Support Tool Calling with Language Agent (#110)
1 parent 60a0570 commit 8a464c5

8 files changed

Lines changed: 1050 additions & 64 deletions


mbodied/agents/backends/anthropic_backend.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from typing import Any, List
 
 import anthropic
@@ -80,7 +81,7 @@ def __init__(self, api_key: str | None, client: anthropic.Anthropic | None = Non
             client: An optional client for the Anthropic service.
             kwargs: Additional keyword arguments.
         """
-        self.api_key = api_key
+        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
         self.client = client
 
         self.model = kwargs.pop("model", self.DEFAULT_MODEL)
@@ -117,13 +118,6 @@ def predict(
         )
         return completion.content[0].text
 
-    async def async_predict(
-        self, message: Message, context: List[Message] | None = None, model: Any | None = None
-    ) -> str:
-        """Asynchronously predict the next message in the conversation."""
-        # For now, we'll use the synchronous method since Anthropic doesn't provide an async API
-        return self.predict(message, context, model)
-
     def stream(
         self, message: Message, context: List[Message] = None, model: str = "claude-3-5-sonnet-20240620", **kwargs
     ):
```
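
In practice, the first hunk means the Anthropic key can now be resolved from the environment. A minimal sketch of the new fallback, assuming the class in this file is named AnthropicBackend (the class name itself is not visible in these hunks):

```python
import os

# Assumed import path and class name; only the file name appears in this diff.
from mbodied.agents.backends.anthropic_backend import AnthropicBackend

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder key
backend = AnthropicBackend(api_key=None)

# With this commit, a missing api_key falls back to the environment variable.
assert backend.api_key == os.environ["ANTHROPIC_API_KEY"]
```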

mbodied/agents/backends/gemini_backend.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -105,7 +105,7 @@ def __init__(
             client: An optional client for the Gemini service.
             **kwargs: Additional keyword arguments.
         """
-        self.api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("MBODI_API_KEY")
+        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        self.client = client
 
         self.model = kwargs.pop("model", self.DEFAULT_MODEL)
```

mbodied/agents/backends/openai_backend.py

Lines changed: 63 additions & 10 deletions
```diff
@@ -25,6 +25,7 @@
 from mbodied.agents.backends.serializer import Serializer
 from mbodied.types.message import Message
 from mbodied.types.sense.vision import Image
+from mbodied.types.tool import Tool, ToolCall
 
 ERRORS = (
     OpenAIRateLimitError,
@@ -99,7 +100,7 @@ def __init__(
             aclient: Whether to use the asynchronous client.
             **kwargs: Additional keyword arguments.
         """
-        self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("MBODI_API_KEY")
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
         self.client = client
         if self.client is None:
             from openai import AsyncOpenAI, OpenAI
@@ -119,18 +120,26 @@ def __init__(
         on_backoff=lambda details: print(f"Backing off {details['wait']:.1f} seconds after {details['tries']} tries."),  # noqa
     )
     def predict(
-        self, message: Message, context: List[Message] | None = None, model: Any | None = None, **kwargs
-    ) -> str:
+        self,
+        message: Message,
+        context: List[Message] | None = None,
+        model: Any | None = None,
+        tools: List[Tool] | None = None,
+        **kwargs,
+    ) -> str | tuple[str, List[ToolCall]]:
         """Create a completion based on the given message and context.
 
         Args:
             message (Message): The message to process.
             context (Optional[List[Message]]): The context of messages.
             model (Optional[Any]): The model used for processing the messages.
+            tools (Optional[List[Tool]]): The tools to make available for function calling.
             **kwargs: Additional keyword arguments.
 
         Returns:
-            str: The result of the completion.
+            str | tuple[str, List[ToolCall]]:
+                When tools are not provided: Just the text response
+                When tools are provided: A tuple of (text_response, tool_calls)
         """
         context = context or self.INITIAL_CONTEXT
         model = model or self.DEFAULT_MODEL
@@ -141,18 +150,35 @@ def predict(
             messages=serialized_messages,
             temperature=0,
             max_tokens=1000,
+            tools=tools,
             **kwargs,
         )
+        if tools:
+            tool_calls = []
+            if completion.choices[0].message.tool_calls:
+                for tool_call in completion.choices[0].message.tool_calls:
+                    tool_calls.append(ToolCall.model_validate(tool_call))
+            return completion.choices[0].message.content, tool_calls
+
         return completion.choices[0].message.content
 
-    def stream(self, message: Message, context: List[Message] = None, model: str = "gpt-4o", **kwargs):
+    def stream(
+        self, message: Message, context: List[Message] = None, model: str = "gpt-4o", tools: List[Tool] = None, **kwargs
+    ):
         """Streams a completion for the given messages using the OpenAI API standard.
 
         Args:
             message: Message to be sent to the completion API.
             context: The context of the messages.
             model: The model to be used for the completion.
+            tools: Optional list of tools (function calls) available to the model.
             **kwargs: Additional keyword arguments.
+
+        Yields:
+            When tools is None:
+                str: Content delta chunks
+            When tools is provided:
+                tuple[str, Any]: Tuples of (content_delta, tool_call_delta) where either may be None
         """
         model = model or self.DEFAULT_MODEL
         context = context or self.INITIAL_CONTEXT
```
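
To make the new return contract concrete, here is a minimal sketch of calling predict with tools. The Tool and ToolCall models live in mbodied/types/tool.py, which this diff does not show, so the OpenAI-style function schema and the ToolCall attribute access below are assumptions:

```python
from mbodied.agents.backends.openai_backend import OpenAIBackend
from mbodied.types.message import Message
from mbodied.types.tool import Tool

# Hypothetical tool definition in OpenAI's function-calling schema; the actual
# Tool fields are defined in mbodied/types/tool.py, which this diff omits.
get_weather = Tool.model_validate({
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
})

backend = OpenAIBackend(api_key=None)  # falls back to OPENAI_API_KEY

# Without tools, predict returns a plain string; with tools it returns a tuple.
text, tool_calls = backend.predict(
    Message(role="user", content="What's the weather in Tokyo?"),
    tools=[get_weather],
)
for call in tool_calls:
    # Assumes ToolCall mirrors OpenAI's tool-call shape, since it is validated
    # directly from the API's tool_call objects via ToolCall.model_validate.
    print(call.function.name, call.function.arguments)
```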
```diff
@@ -162,19 +188,37 @@ def stream(self, message: Message, context: List[Message] = None, model: str = "
             model=model,
             temperature=0,
             stream=True,
+            tools=tools,
             **kwargs,
         )
-        for chunk in stream:
-            yield chunk.choices[0].delta.content or ""
 
-    async def astream(self, message: Message, context: List[Message] = None, model: str = "gpt-4o", **kwargs):
+        if not tools:
+            for chunk in stream:
+                yield chunk.choices[0].delta.content or ""
+        else:
+            for chunk in stream:
+                delta = chunk.choices[0].delta
+                content = delta.content or ""
+                tool_calls = delta.tool_calls
+                yield content, tool_calls
+
+    async def astream(
+        self, message: Message, context: List[Message] = None, model: str = "gpt-4o", tools: List[Tool] = None, **kwargs
+    ):
         """Streams a completion asynchronously for the given messages using the OpenAI API standard.
 
         Args:
             message: Message to be sent to the completion API.
             context: The context of the messages.
             model: The model to be used for the completion.
+            tools: Optional list of tools (function calls) available to the model.
             **kwargs: Additional keyword arguments.
+
+        Yields:
+            When tools is None:
+                str: Content delta chunks
+            When tools is provided:
+                tuple[str, Any]: Tuples of (content_delta, tool_call_delta) where either may be None
         """
         if not hasattr(self, "aclient"):
             raise AttributeError("AsyncOpenAI client not initialized. Pass in aclient=True to the constructor.")
```
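
Consuming the tool-aware stream follows the usual OpenAI delta-accumulation pattern, since the tool_call_delta yielded here is the raw delta.tool_calls list from each chunk. A sketch, reusing backend and get_weather from the example above:

```python
# Each yielded item is (content_delta, tool_call_deltas). Tool-call fragments
# arrive OpenAI-style: the function name on the first fragment for an index,
# the JSON arguments spread across subsequent fragments.
text_parts: list[str] = []
calls: dict[int, dict[str, str]] = {}

for content, tool_call_deltas in backend.stream(
    Message(role="user", content="What's the weather in Tokyo?"),
    tools=[get_weather],
):
    text_parts.append(content)
    for tc in tool_call_deltas or []:
        slot = calls.setdefault(tc.index, {"name": "", "arguments": ""})
        if tc.function.name:
            slot["name"] = tc.function.name
        if tc.function.arguments:
            slot["arguments"] += tc.function.arguments

print("".join(text_parts))
print(calls)  # e.g. {0: {'name': 'get_weather', 'arguments': '{"city": "Tokyo"}'}}
```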
```diff
@@ -186,7 +230,16 @@ async def astream(self, message: Message, context: List[Message] = None, model:
             model=model,
             temperature=0,
             stream=True,
+            tools=tools,
             **kwargs,
         )
-        async for chunk in stream:
-            yield chunk.choices[0].delta.content or ""
+
+        if not tools:
+            async for chunk in stream:
+                yield chunk.choices[0].delta.content or ""
+        else:
+            async for chunk in stream:
+                delta = chunk.choices[0].delta
+                content = delta.content or ""
+                tool_calls = delta.tool_calls
+                yield content, tool_calls
```
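
The async variant is consumed the same way with async for; note the aclient=True constructor requirement surfaced by the hasattr check above. A sketch under the same assumptions as the earlier examples:

```python
import asyncio

async def main() -> None:
    backend = OpenAIBackend(api_key=None, aclient=True)  # astream needs the async client
    async for content, tool_call_deltas in backend.astream(
        Message(role="user", content="What's the weather in Tokyo?"),
        tools=[get_weather],
    ):
        if content:
            print(content, end="", flush=True)
        # tool_call_deltas can be accumulated exactly as in the sync example.

asyncio.run(main())
```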
