import json
import logging
from typing import AsyncGenerator
from uuid import uuid4

from magentic import (
    AssistantMessage,
    AsyncStreamedStr,
    OpenaiChatModel,
    SystemMessage,
    UserMessage,
    chatprompt,
    prompt,
)

from .models import (
    AgentQueryRequest,
    ArtifactSSE,
    ArtifactSSEData,
    StatusUpdateSSE,
    StatusUpdateSSEData,
    TaskStructure,
)
from .portfolio import prepare_allocation, save_allocation
from .prompts import (
    DO_I_NEED_TO_ALLOCATE_THE_PORTFOLIO_PROMPT,
    PARSE_USER_MESSAGE_TO_STRUCTURE_THE_TASK,
    SYSTEM_PROMPT,
)
from .utils import generate_id, is_last_message, sanitize_message

logger = logging.getLogger(__name__)


def need_to_allocate_portfolio(conversation: str) -> bool:
    """Determine if the user needs to allocate the portfolio right now."""

    @prompt(
        DO_I_NEED_TO_ALLOCATE_THE_PORTFOLIO_PROMPT,
        model=OpenaiChatModel(model="gpt-4o-mini", temperature=0.0),
    )
    def _need_to_allocate_portfolio(conversation: str) -> bool: ...

    # I'm using a while loop here for exception handling. This is to retry
    # the prompt if the LLM returns something that is not a boolean.
    attempt = 0
    while attempt <= 5:
        try:
            need_to_allocate = _need_to_allocate_portfolio(conversation)
            if isinstance(need_to_allocate, bool):
                return need_to_allocate
            raise ValueError("Need to allocate portfolio is not a boolean")
        except Exception as e:
            attempt += 1
            if attempt > 5:
                logger.error(f"Error determining whether to allocate the portfolio: {e}")
                raise


def get_task_structure(messages: str) -> TaskStructure:
    """Get the task structure from the user messages."""

    @prompt(
        PARSE_USER_MESSAGE_TO_STRUCTURE_THE_TASK,
        model=OpenaiChatModel(model="gpt-4o", temperature=0.0),
    )
    def parse_user_message(conversation: str) -> TaskStructure: ...

    # I'm using a while loop here for exception handling. This is to retry the
    # prompt if the LLM fails to return an answer with the required structure.
    attempt = 0
    while attempt <= 5:
        try:
            task_structure = parse_user_message(messages)
            return task_structure
        except Exception as e:
            attempt += 1
            if attempt > 5:
                logger.error(f"Error parsing user message: {e}")
                raise


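# The two retry loops above share the same shape. As a minimal, illustrative
# sketch (hypothetical and not used by the code above), they could be factored
# into a single helper like this:
def _retry_prompt(fn, *args, max_attempts: int = 5):
    """Call `fn`, retrying on any exception up to `max_attempts` extra times."""
    attempt = 0
    while True:
        try:
            return fn(*args)
        except Exception as e:
            attempt += 1
            if attempt > max_attempts:
                logger.error(f"Prompt failed after {attempt} attempts: {e}")
                raise

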
async def execution_loop(request: AgentQueryRequest) -> AsyncGenerator[dict, None]:
    """Process the query and generate responses."""

    chat_messages = []
    for message in request.messages:
        if message.role == "ai":
            chat_messages.append(
                AssistantMessage(content=sanitize_message(message.content))
            )
        elif message.role == "human":
            user_message_content = sanitize_message(message.content)
            chat_messages.append(UserMessage(content=user_message_content))
            if is_last_message(message, request.messages):

                # I'm intentionally not using function calling in this example.
                # I want all of the logic to be exposed explicitly, so that
                # others can use this code as a reference to learn what is
                # going on under the hood and how server-sent events are
                # yielded.
                if need_to_allocate_portfolio(str(chat_messages)):
                    yield StatusUpdateSSE(
                        data=StatusUpdateSSEData(
                            eventType="INFO",
                            message="Starting asset basket allocation...\n"
                            + "Fetching task structure...",
                        ),
                    ).model_dump()
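                    # Note: after model_dump() each SSE is a plain dict (the
                    # assumed shape here is {"event": ..., "data": {...}})
                    # that the transport layer serializes for the client.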

                    task_structure = get_task_structure(str(chat_messages))

                    yield StatusUpdateSSE(
                        data=StatusUpdateSSEData(
                            eventType="INFO",
                            message="Task structure:",
                            details=[task_structure.__pretty_dict__()],
                        ),
                    ).model_dump()

                    yield StatusUpdateSSE(
                        data=StatusUpdateSSEData(
                            eventType="INFO", message="Fetching historical prices..."
                        ),
                    ).model_dump()

                    task_dict = task_structure.model_dump()
                    task_dict.pop("task")
                    allocation = None
                    try:
                        allocation = prepare_allocation(**task_dict)
                    except Exception as e:
                        yield StatusUpdateSSE(
                            event="copilotStatusUpdate",
                            data=StatusUpdateSSEData(
                                eventType="ERROR",
                                message=f"Error preparing allocation. {str(e)}",
                            ),
                        ).model_dump()
                        chat_messages.append(
                            AssistantMessage(
                                content=f"Error preparing allocation. {str(e)}"
                            )
                        )
                        chat_messages.append(UserMessage(content="What should I do?"))

                    if allocation is not None:
                        try:
                            yield StatusUpdateSSE(
                                event="copilotStatusUpdate",
                                data=StatusUpdateSSEData(
                                    eventType="INFO",
                                    message="Basket weights optimized. Saving allocation...",
                                ),
                            ).model_dump()

                            allocation_id = save_allocation(
                                allocation_id=generate_id(length=2),
                                allocation_data=allocation.to_dict(orient="records"),
                            )

                            yield StatusUpdateSSE(
                                event="copilotStatusUpdate",
                                data=StatusUpdateSSEData(
                                    eventType="INFO",
                                    message="Allocation saved successfully.",
                                ),
                            ).model_dump()

                            yield ArtifactSSE(
                                data=ArtifactSSEData(
                                    type="table",
                                    name="Allocation",
                                    description="Allocation of assets in the basket.",
                                    uuid=uuid4(),
                                    content=allocation.to_dict(orient="records"),
                                ),
                            ).model_dump()

                            chat_messages.append(
                                AssistantMessage(
                                    content=sanitize_message(
                                        f"Allocation created. Allocation id is `{allocation_id}`. Allocation data is {allocation.to_markdown()}."
                                    )
                                )
                            )
                            chat_messages.append(
                                UserMessage(
                                    content="Write short sub-paragraph summary reports on the allocation for each of the risk models. "
                                    + "At the end of your message include the allocation id formatted as an inline code block."
                                )
                            )
                        except Exception as e:
                            yield StatusUpdateSSE(
                                event="copilotStatusUpdate",
                                data=StatusUpdateSSEData(
                                    eventType="ERROR",
                                    message=f"Error saving allocation. {str(e)}",
                                ),
                            ).model_dump()

    @chatprompt(
        SystemMessage(SYSTEM_PROMPT),
        *chat_messages,
        model=OpenaiChatModel(model="gpt-4o", temperature=0.7),
    )
    async def _llm() -> AsyncStreamedStr | str: ...

    llm_result = await _llm()

    if isinstance(llm_result, str):
        yield {
            "event": "copilotMessageChunk",
            "data": json.dumps({"delta": llm_result}),
        }
    else:
        async for chunk in llm_result:
            yield {
                "event": "copilotMessageChunk",
                "data": json.dumps({"delta": chunk}),
            }
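
# Illustrative only (a sketch, not part of the original module): one way the
# generator above might be served as server-sent events from a FastAPI app,
# assuming `sse-starlette` is installed. The route path is hypothetical.
#
#     from fastapi import FastAPI
#     from sse_starlette.sse import EventSourceResponse
#
#     app = FastAPI()
#
#     @app.post("/v1/query")
#     async def query(request: AgentQueryRequest) -> EventSourceResponse:
#         return EventSourceResponse(execution_loop(request))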