2525from mbodied .agents .backends .serializer import Serializer
2626from mbodied .types .message import Message
2727from mbodied .types .sense .vision import Image
28+ from mbodied .types .tool import Tool , ToolCall
2829
2930ERRORS = (
3031 OpenAIRateLimitError ,
@@ -99,7 +100,7 @@ def __init__(
99100 aclient: Whether to use the asynchronous client.
100101 **kwargs: Additional keyword arguments.
101102 """
102- self .api_key = api_key or os .getenv ("OPENAI_API_KEY" ) or os . getenv ( "MBODI_API_KEY" )
103+ self .api_key = api_key or os .getenv ("OPENAI_API_KEY" )
103104 self .client = client
104105 if self .client is None :
105106 from openai import AsyncOpenAI , OpenAI
@@ -119,18 +120,26 @@ def __init__(
119120 on_backoff = lambda details : print (f"Backing off { details ['wait' ]:.1f} seconds after { details ['tries' ]} tries." ), # noqa
120121 )
121122 def predict (
122- self , message : Message , context : List [Message ] | None = None , model : Any | None = None , ** kwargs
123- ) -> str :
123+ self ,
124+ message : Message ,
125+ context : List [Message ] | None = None ,
126+ model : Any | None = None ,
127+ tools : List [Tool ] | None = None ,
128+ ** kwargs ,
129+ ) -> str | tuple [str , List [ToolCall ]]:
124130 """Create a completion based on the given message and context.
125131
126132 Args:
127133 message (Message): The message to process.
128134 context (Optional[List[Message]]): The context of messages.
129135 model (Optional[Any]): The model used for processing the messages.
136+ tools (Optional[List[Tool]]): The tools to make available for function calling.
130137 **kwargs: Additional keyword arguments.
131138
132139 Returns:
133- str: The result of the completion.
140+ str | tuple[str, List[ToolCall]]:
141+ When tools are not provided: Just the text response
142+ When tools are provided: A tuple of (text_response, tool_calls)
134143 """
135144 context = context or self .INITIAL_CONTEXT
136145 model = model or self .DEFAULT_MODEL
@@ -141,18 +150,35 @@ def predict(
141150 messages = serialized_messages ,
142151 temperature = 0 ,
143152 max_tokens = 1000 ,
153+ tools = tools ,
144154 ** kwargs ,
145155 )
156+ if tools :
157+ tool_calls = []
158+ if completion .choices [0 ].message .tool_calls :
159+ for tool_call in completion .choices [0 ].message .tool_calls :
160+ tool_calls .append (ToolCall .model_validate (tool_call ))
161+ return completion .choices [0 ].message .content , tool_calls
162+
146163 return completion .choices [0 ].message .content
147164
148- def stream (self , message : Message , context : List [Message ] = None , model : str = "gpt-4o" , ** kwargs ):
165+ def stream (
166+ self , message : Message , context : List [Message ] = None , model : str = "gpt-4o" , tools : List [Tool ] = None , ** kwargs
167+ ):
149168 """Streams a completion for the given messages using the OpenAI API standard.
150169
151170 Args:
152171 message: Message to be sent to the completion API.
153172 context: The context of the messages.
154173 model: The model to be used for the completion.
174+ tools: Optional list of tools (function calls) available to the model.
155175 **kwargs: Additional keyword arguments.
176+
177+ Yields:
178+ When tools is None:
179+ str: Content delta chunks
180+ When tools is provided:
181+ tuple[str, Any]: Tuples of (content_delta, tool_call_delta) where either may be None
156182 """
157183 model = model or self .DEFAULT_MODEL
158184 context = context or self .INITIAL_CONTEXT
@@ -162,19 +188,37 @@ def stream(self, message: Message, context: List[Message] = None, model: str = "
162188 model = model ,
163189 temperature = 0 ,
164190 stream = True ,
191+ tools = tools ,
165192 ** kwargs ,
166193 )
167- for chunk in stream :
168- yield chunk .choices [0 ].delta .content or ""
169194
170- async def astream (self , message : Message , context : List [Message ] = None , model : str = "gpt-4o" , ** kwargs ):
195+ if not tools :
196+ for chunk in stream :
197+ yield chunk .choices [0 ].delta .content or ""
198+ else :
199+ for chunk in stream :
200+ delta = chunk .choices [0 ].delta
201+ content = delta .content or ""
202+ tool_calls = delta .tool_calls
203+ yield content , tool_calls
204+
205+ async def astream (
206+ self , message : Message , context : List [Message ] = None , model : str = "gpt-4o" , tools : List [Tool ] = None , ** kwargs
207+ ):
171208 """Streams a completion asynchronously for the given messages using the OpenAI API standard.
172209
173210 Args:
174211 message: Message to be sent to the completion API.
175212 context: The context of the messages.
176213 model: The model to be used for the completion.
214+ tools: Optional list of tools (function calls) available to the model.
177215 **kwargs: Additional keyword arguments.
216+
217+ Yields:
218+ When tools is None:
219+ str: Content delta chunks
220+ When tools is provided:
221+ tuple[str, Any]: Tuples of (content_delta, tool_call_delta) where either may be None
178222 """
179223 if not hasattr (self , "aclient" ):
180224 raise AttributeError ("AsyncOpenAI client not initialized. Pass in aclient=True to the constructor." )
@@ -186,7 +230,16 @@ async def astream(self, message: Message, context: List[Message] = None, model:
186230 model = model ,
187231 temperature = 0 ,
188232 stream = True ,
233+ tools = tools ,
189234 ** kwargs ,
190235 )
191- async for chunk in stream :
192- yield chunk .choices [0 ].delta .content or ""
236+
237+ if not tools :
238+ async for chunk in stream :
239+ yield chunk .choices [0 ].delta .content or ""
240+ else :
241+ async for chunk in stream :
242+ delta = chunk .choices [0 ].delta
243+ content = delta .content or ""
244+ tool_calls = delta .tool_calls
245+ yield content , tool_calls
0 commit comments