diff --git a/config/config.yaml b/config/config.yaml index dc4c4ea5ad..f547462bab 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,7 +23,7 @@ RPM: 10 #SPARK_URL : "ws://spark-api.xf-yun.com/v2.1/chat" #### if Anthropic -#Anthropic_API_KEY: "YOUR_API_KEY" +#ANTHROPIC_API_KEY: "YOUR_API_KEY" #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb #### You can use ENGINE or DEPLOYMENT mode diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index afc9ff445e..3cb03f3740 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -30,10 +30,10 @@ To reach version v0.5, approximately 70% of the following tasks need to be compl 4. Complete the design and implementation of module breakdown 5. Support various modes of memory: clearly distinguish between long-term and short-term memory 6. Perfect the test role, and carry out necessary interactions with humans - 7. Allowing natural communication between roles (expected v0.5.0) + 7. ~~Allowing natural communication between roles~~ (v0.5.0) 8. Implement SkillManager and the process of incremental Skill learning (experimentation done with game agents) 9. Automatically get RPM and configure it by calling the corresponding openai page, so that each key does not need to be manually configured - 10. IMPORTANT: Support incremental development (expected v0.5.0) + 10. ~~IMPORTANT: Support incremental development~~ (v0.5.0) 3. Strategies 1. Support ReAct strategy (experimentation done with game agents) 2. Support CoT strategy (experimentation done with game agents) @@ -45,8 +45,8 @@ To reach version v0.5, approximately 70% of the following tasks need to be compl 2. Implementation: Knowledge search, supporting 10+ data formats 3. Implementation: Data EDA (expected v0.6.0) 4. Implementation: Review - 5. Implementation: Add Document (expected v0.5.0) - 6. Implementation: Delete Document (expected v0.5.0) + 5. ~~Implementation~~: Add Document (v0.5.0) + 6. ~~Implementation~~: Delete Document (v0.5.0) 7. Implementation: Self-training 8. ~~Implementation: DebugError~~ (v0.2.1) 9. 
Implementation: Generate reliable unit tests based on YAPI diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 05417d24ae..26af8a287a 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -12,9 +12,8 @@ from metagpt.roles import Role from metagpt.schema import Message -with open(METAGPT_ROOT / "examples/build_customized_agent.py", "r") as f: - # use official example script to guide AgentCreator - MULTI_ACTION_AGENT_CODE_EXAMPLE = f.read() +EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" +MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() class CreateAgent(Action): @@ -50,8 +49,8 @@ def parse_code(rsp): match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" CONFIG.workspace_path.mkdir(parents=True, exist_ok=True) - with open(CONFIG.workspace_path / "agent_created_agent.py", "w") as f: - f.write(code_text) + new_file = CONFIG.workspace_path / "agent_created_agent.py" + new_file.write_text(code_text) return code_text diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index 79ff94b3ef..c34c72ed2f 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -13,7 +13,6 @@ from metagpt.actions.debug_error import DebugError from metagpt.actions.design_api import WriteDesign from metagpt.actions.design_api_review import DesignReview -from metagpt.actions.design_filenames import DesignFilenames from metagpt.actions.project_management import AssignTasks, WriteTasks from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch from metagpt.actions.run_code import RunCode @@ -33,7 +32,6 @@ class ActionType(Enum): WRITE_PRD_REVIEW = WritePRDReview WRITE_DESIGN = WriteDesign DESIGN_REVIEW = DesignReview - DESIGN_FILENAMES = DesignFilenames WRTIE_CODE = WriteCode WRITE_CODE_REVIEW = WriteCodeReview WRITE_TEST = WriteTest diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1534b1f4d5..a3a9c01957 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -6,20 +6,27 @@ @File : action.py """ +from __future__ import annotations + from abc import ABC from typing import Optional -from tenacity import retry, stop_after_attempt, wait_random_exponential - -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess -from metagpt.utils.common import OutputParser -from metagpt.utils.utils import general_after_log +from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext class Action(ABC): + """Action abstract class, requiring all inheritors to provide a series of standard capabilities""" + + name: str + llm: LLM + # FIXME: simplify context + context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None + prefix: str + desc: str + node: ActionNode | None + def __init__(self, name: str = "", context=None, llm: LLM = None): self.name: str = name if llm is None: @@ -27,22 +34,12 @@ def __init__(self, name: str = "", context=None, llm: LLM = None): self.llm = llm self.context = context self.prefix = "" # aask*时会加上prefix,作为system_message - self.profile = "" # FIXME: USELESS self.desc = "" # for skill manager - self.nodes = ... 
- - # Output, useless - # self.content = "" - # self.instruct_content = None - # self.env = None + self.node = None - # def set_env(self, env): - # self.env = env - - def set_prefix(self, prefix, profile): + def set_prefix(self, prefix): """Set prefix for later usage""" self.prefix = prefix - self.profile = profile return self def __str__(self): @@ -58,33 +55,6 @@ async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> s system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - @retry( - wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(6), - after=general_after_log(logger), - ) - async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: - content = await self.llm.aask(prompt, system_msgs) - logger.debug(f"llm raw output:\n{content}") - output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") - - else: # using markdown parser - parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(f"parsed_data:\n{parsed_data}") - instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) - async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index fb7d621d8b..092dd57552 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -6,17 +6,15 @@ @File : action_node.py """ import json -import re -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential -from metagpt.actions import ActionOutput from metagpt.llm import BaseGPTAPI from metagpt.logs import logger -from metagpt.utils.common import OutputParser -from metagpt.utils.custom_decoder import CustomDecoder +from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.utils.common import OutputParser, general_after_log CONSTRAINT = """ - Language: Please use the same language as the user input. @@ -50,8 +48,12 @@ def dict_to_markdown(d, prefix="-", postfix="\n"): return markdown_str -class ActionNode: +T = TypeVar("T") + + +class ActionNode(Generic[T]): """ActionNode is a tree of nodes.""" + mode: str # Action Context @@ -64,14 +66,21 @@ class ActionNode: expected_type: Type # such as str / int / float etc. # context: str # everything in the history. instruction: str # the instructions should be followed. - example: Any # example for In Context-Learning. + example: T # example for In Context-Learning. 
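+    # NOTE: `example` now shares the TypeVar T declared via ActionNode(Generic[T]) above,
+    # so a node such as ActionNode[list[str]] carries a correctly typed in-context example.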
# Action Output content: str instruct_content: BaseModel - def __init__(self, key: str, expected_type: Type, instruction: str, example: str, content: str = "", - children: dict[str, "ActionNode"] = None): + def __init__( + self, + key: str, + expected_type: Type, + instruction: str, + example: T, + content: str = "", + children: dict[str, "ActionNode"] = None, + ): self.key = key self.expected_type = expected_type self.instruction = instruction @@ -103,22 +112,22 @@ def from_children(cls, key, nodes: List["ActionNode"]): obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Type]: + def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" return {k: (v.expected_type, ...) for k, v in self.children.items()} - def get_self_mapping(self) -> Dict[str, Type]: + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Type]: + def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): return self.get_children_mapping() return self.get_self_mapping() @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): + def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" new_class = create_model(class_name, **mapping) @@ -140,29 +149,6 @@ def check_missing_fields(values): new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - @classmethod - def create_model_class_v2(cls, class_name: str, mapping: Dict[str, Type]): - """基于pydantic v2的模型动态生成,用来检验结果类型正确性,待验证""" - new_class = create_model(class_name, **mapping) - - @model_validator(mode="before") - def check_missing_fields(data): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(data.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return data - - @field_validator("*") - def check_name(v: Any, field: str) -> Any: - if field not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field}") - return v - - new_class.__model_validator_check_missing_fields = classmethod(check_missing_fields) - new_class.__field_validator_check_name = classmethod(check_name) - return new_class - def create_children_class(self): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" @@ -189,46 +175,46 @@ def to_dict(self, format_func=None, mode="auto") -> Dict: return node_dict # 遍历子节点并递归调用 to_dict 方法 - for child_key, child_node in self.children.items(): + for _, child_node in self.children.items(): node_dict.update(child_node.to_dict(format_func)) return node_dict - def compile_to(self, i: Dict, to) -> str: - if to == "json": + def compile_to(self, i: Dict, schema) -> str: + if schema == "json": return json.dumps(i, indent=4) - elif to == "markdown": + elif schema == "markdown": return dict_to_markdown(i) else: return str(i) - def tagging(self, text, to, tag="") -> str: + def tagging(self, text, schema, tag="") -> str: if not tag: return text - if to == "json": + if schema == "json": return f"[{tag}]\n" + text + f"\n[/{tag}]" else: return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, to, mode, tag, format_func) -> str: + def _compile_f(self, schema, mode, tag, format_func) -> str: nodes = self.to_dict(format_func=format_func, mode=mode) 
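+        # `schema` (one of "raw" / "json" / "markdown") replaces the old `to` parameter
+        # and selects how the collected node dict is serialized below.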
-        text = self.compile_to(nodes, to)
-        return self.tagging(text, to, tag)
+        text = self.compile_to(nodes, schema)
+        return self.tagging(text, schema, tag)

-    def compile_instruction(self, to="raw", mode="children", tag="") -> str:
+    def compile_instruction(self, schema="raw", mode="children", tag="") -> str:
         """compile to raw/json/markdown template with all/root/children nodes"""
         format_func = lambda i: f"{i.expected_type} # {i.instruction}"
-        return self._compile_f(to, mode, tag, format_func)
+        return self._compile_f(schema, mode, tag, format_func)

-    def compile_example(self, to="raw", mode="children", tag="") -> str:
+    def compile_example(self, schema="raw", mode="children", tag="") -> str:
         """compile to raw/json/markdown examples with all/root/children nodes"""

         # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example
         # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str
         format_func = lambda i: i.example
-        return self._compile_f(to, mode, tag, format_func)
+        return self._compile_f(schema, mode, tag, format_func)

-    def compile(self, context, to="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
+    def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
         """
         mode: all/root/children
         mode="children": 编译所有子节点为一个统一模板,包括instruction与example
@@ -237,43 +223,40 @@ def compile(self, context, to="json", mode="children", template=SIMPLE_TEMPLATE)
         """
         # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048  # 项目名称使用下划线",
-        self.instruction = self.compile_instruction(to="markdown", mode=mode)
-        self.example = self.compile_example(to=to, tag="CONTENT", mode=mode)
+        # compile_example does not support markdown output yet
+        self.instruction = self.compile_instruction(schema="markdown", mode=mode)
+        self.example = self.compile_example(schema=schema, tag="CONTENT", mode=mode)
         prompt = template.format(
             context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT
         )
         return prompt

-    @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
+    @retry(
+        wait=wait_random_exponential(min=1, max=20),
+        stop=stop_after_attempt(6),
+        after=general_after_log(logger),
+    )
     async def _aask_v1(
         self,
         prompt: str,
         output_class_name: str,
         output_data_mapping: dict,
         system_msgs: Optional[list[str]] = None,
-        format="markdown",  # compatible to original format
-    ) -> ActionOutput:
+        schema="markdown",  # compatible to original format
+    ) -> Tuple[str, BaseModel]:
+        """Ask the LLM, then parse and validate the reply into (content, instruct_content)"""
         content = await self.llm.aask(prompt, system_msgs)
-        logger.debug(content)
-        output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
-
-        if format == "json":
-            pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
-            matches = re.findall(pattern, content, re.DOTALL)
-
-            for match in matches:
-                if match:
-                    content = match
-                    break
-
-            parsed_data = CustomDecoder(strict=False).decode(content)
+        logger.debug(f"llm raw output:\n{content}")
+        output_class = self.create_model_class(output_class_name, output_data_mapping)
+        if schema == "json":
+            parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]")
         else:  # using markdown parser
             parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
-        logger.debug(parsed_data)
+        logger.debug(f"parsed_data:\n{parsed_data}")
         instruct_content = output_class(**parsed_data)
-        return ActionOutput(content, instruct_content)
+        return content, instruct_content

     def get(self, key):
         return
self.instruct_content.dict()[key] @@ -289,22 +272,22 @@ def set_llm(self, llm): def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, to, mode): - prompt = self.compile(context=self.context, to=to, mode=mode) + async def simple_fill(self, schema, mode): + prompt = self.compile(context=self.context, schema=schema, mode=mode) mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" - output = await self._aask_v1(prompt, class_name, mapping, format=to) - self.content = output.content - self.instruct_content = output.instruct_content + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema) + self.content = content + self.instruct_content = scontent return self - async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"): """Fill the node(s) with mode. :param context: Everything we should know when filling node. :param llm: Large Language Model with pre-defined system message. - :param to: json/markdown, determine example and output format. + :param schema: json/markdown, determine example and output format. - json: it's easy to open source LLM with json format - markdown: when generating code, markdown is always better :param mode: auto/children/root @@ -320,12 +303,12 @@ async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): self.set_context(context) if strgy == "simple": - return await self.simple_fill(to, mode) + return await self.simple_fill(schema, mode) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(to, mode) + child = await i.simple_fill(schema, mode) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/action_output.py b/metagpt/actions/action_output.py index 25326d43b4..6be8dac50e 100644 --- a/metagpt/actions/action_output.py +++ b/metagpt/actions/action_output.py @@ -6,9 +6,7 @@ @File : action_output """ -from typing import Dict, Type - -from pydantic import BaseModel, create_model, root_validator, validator +from pydantic import BaseModel class ActionOutput: @@ -18,25 +16,3 @@ class ActionOutput: def __init__(self, content: str, instruct_content: BaseModel): self.content = content self.instruct_content = instruct_content - - @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): - new_class = create_model(class_name, **mapping) - - @validator("*", allow_reuse=True) - def check_name(v, field): - if field.name not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field.name}") - return v - - @root_validator(pre=True, allow_reuse=True) - def check_missing_fields(values): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(values.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return values - - new_class.__validator_check_name = classmethod(check_name) - new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) - return new_class diff --git a/metagpt/actions/analyze_dep_libs.py b/metagpt/actions/analyze_dep_libs.py deleted file mode 100644 index 53d40200aa..0000000000 --- a/metagpt/actions/analyze_dep_libs.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 12:01 -@Author : alexanderwu -@File : analyze_dep_libs.py -""" - -from metagpt.actions 
import Action - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. - -For the user's prompt: - ---- -The API is: {prompt} ---- - -We decide the generated files are: {filepaths_string} - -Now that we have a file list, we need to understand the shared dependencies they have. -Please list and briefly describe the shared contents between the files we are generating, including exported variables, -data patterns, id names of all DOM elements that javascript functions will use, message names and function names. -Focus only on the names of shared dependencies, do not add any other explanations. -""" - - -class AnalyzeDepLibs(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = "Analyze the runtime dependencies of the program based on the context" - - async def run(self, requirement, filepaths_string): - # prompt = f"Below is the product requirement document (PRD):\n\n{prd}\n\n{PROMPT}" - prompt = PROMPT.format(prompt=requirement, filepaths_string=filepaths_string) - design_filenames = await self._aask(prompt) - return design_filenames diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 5a5f52de7e..548725fdec 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -50,7 +50,7 @@ def __init__(self, name, context=None, llm=None): "clearly and in detail." ) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages, schema=CONFIG.prompt_schema): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files @@ -80,13 +80,13 @@ async def run(self, with_messages, format=CONFIG.prompt_format): # leaving room for global optimization in subsequent steps. return ActionOutput(content=changed_files.json(), instruct_content=changed_files) - async def _new_system_design(self, context, format=CONFIG.prompt_format): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + async def _new_system_design(self, context, schema=CONFIG.prompt_schema): + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) return node - async def _merge(self, prd_doc, system_design_doc, format=CONFIG.prompt_format): + async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) system_design_doc.content = node.instruct_content.json(ensure_ascii=False) return system_design_doc diff --git a/metagpt/actions/design_filenames.py b/metagpt/actions/design_filenames.py deleted file mode 100644 index ffa171d7b5..0000000000 --- a/metagpt/actions/design_filenames.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 11:50 -@Author : alexanderwu -@File : design_filenames.py -""" -from metagpt.actions import Action -from metagpt.logs import logger - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. -When given their intentions, provide a complete and exhaustive list of file paths needed to write the program for the user. 
-Only list the file paths you will write and return them as a Python string list. -Do not add any other explanations, just return a Python string list.""" - - -class DesignFilenames(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, consider system design, and carry out the basic design of the corresponding " - "APIs, data structures, and database tables. Please give your design, feedback clearly and in detail." - ) - - async def run(self, prd): - prompt = f"The following is the Product Requirement Document (PRD):\n\n{prd}\n\n{PROMPT}" - design_filenames = await self._aask(prompt) - logger.debug(prompt) - logger.debug(design_filenames) - return design_filenames diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/detail_mining.py deleted file mode 100644 index 5afcf52c65..0000000000 --- a/metagpt/actions/detail_mining.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/9/12 17:45 -@Author : fisherdeng -@File : detail_mining.py -""" -from metagpt.actions import Action, ActionOutput - -PROMPT_TEMPLATE = """ -##TOPIC -{topic} - -##RECORD -{record} - -##Format example -{format_example} ------ - -Task: Refer to the "##TOPIC" (discussion objectives) and "##RECORD" (discussion records) to further inquire about the details that interest you, within a word limit of 150 words. -Special Note 1: Your intention is solely to ask questions without endorsing or negating any individual's viewpoints. -Special Note 2: This output should only include the topic "##OUTPUT". Do not add, remove, or modify the topic. Begin the output with '##OUTPUT', followed by an immediate line break, and then proceed to provide the content in the specified format as outlined in the "##Format example" section. -Special Note 3: The output should be in the same language as the input. -""" -FORMAT_EXAMPLE = """ - -## - -##OUTPUT -...(Please provide the specific details you would like to inquire about here.) - -## - -## -""" -OUTPUT_MAPPING = { - "OUTPUT": (str, ...), -} - - -class DetailMining(Action): - """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" - - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) - - async def run(self, topic, record) -> ActionOutput: - prompt = PROMPT_TEMPLATE.format(topic=topic, record=record, format_example=FORMAT_EXAMPLE) - rsp = await self._aask_v1(prompt, "detail_mining", OUTPUT_MAPPING) - return rsp diff --git a/metagpt/actions/generate_questions.py b/metagpt/actions/generate_questions.py new file mode 100644 index 0000000000..c38c463bcc --- /dev/null +++ b/metagpt/actions/generate_questions.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/12 17:45 +@Author : fisherdeng +@File : generate_questions.py +""" +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="Task: Refer to the context to further inquire about the details that interest you, within a word limit" + " of 150 words. Please provide the specific details you would like to inquire about here", + example=["1. What ...", "2. How ...", "3. 
..."], +) + + +class GenerateQuestions(Action): + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" + + async def run(self, context): + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index b2704616e0..7ed42d5902 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -6,35 +6,18 @@ @File : prepare_interview.py """ from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode -PROMPT_TEMPLATE = """ -# Context -{context} - -## Format example ---- -Q1: question 1 here -References: - - point 1 - - point 2 - -Q2: question 2 here... ---- - ------ -Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="""Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; Requirement: Provide a list of questions for the interviewer to ask the interviewee, by reading the resume of the interviewee in the context. -Attention: Provide as markdown block as the format above, at least 10 questions. -""" - -# prepare for a interview +Attention: Provide as markdown block as the format above, at least 10 questions.""", + example=["1. What ...", "2. How ..."], +) class PrepareInterview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - async def run(self, context): - prompt = PROMPT_TEMPLATE.format(context=context) - question_list = await self._aask_v1(prompt) - return question_list + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 1f14e79447..fe2c8d537c 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -42,7 +42,7 @@ class WriteTasks(Action): def __init__(self, name="CreateTasks", context=None, llm=None): super().__init__(name, context, llm) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files @@ -89,16 +89,16 @@ async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo await self._save_pdf(task_doc=task_doc) return task_doc - async def _run_new_tasks(self, context, format=CONFIG.prompt_format): - node = await PM_NODE.fill(context, self.llm, format) + async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema): + node = await PM_NODE.fill(context, self.llm, schema) # prompt_template, format_example = get_template(templates, format) # prompt = prompt_template.format(context=context, format_example=format_example) # rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format) return node - async def _merge(self, system_design_doc, task_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content) - node = await PM_NODE.fill(context, self.llm, format) + node = await PM_NODE.fill(context, self.llm, schema) 
task_doc.content = node.instruct_content.json(ensure_ascii=False) return task_doc diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py index 970cb0594f..6208c1051f 100644 --- a/metagpt/actions/project_management_an.py +++ b/metagpt/actions/project_management_an.py @@ -44,7 +44,7 @@ key="Full API spec", expected_type=str, instruction="Describe all APIs using OpenAPI 3.0 spec that may be used by both frontend and backend. If front-end " - "and back-end communication is not required, leave it blank.", + "and back-end communication is not required, leave it blank.", example="openapi: 3.0.0 ...", ) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index fa13a09808..1b9fd252fe 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -16,13 +16,13 @@ class. """ import subprocess -import traceback from typing import Tuple from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import RunCodeResult +from metagpt.utils.exceptions import handle_exception PROMPT_TEMPLATE = """ Role: You are a senior development and qa engineer, your role is summarize the code running result. @@ -78,15 +78,12 @@ def __init__(self, name="RunCode", context=None, llm=None): super().__init__(name, context, llm) @classmethod + @handle_exception async def run_text(cls, code) -> Tuple[str, str]: - try: - # We will document_store the result in this dictionary - namespace = {} - exec(code, namespace) - return namespace.get("result", ""), "" - except Exception: - # If there is an error in the code, return the error message - return "", traceback.format_exc() + # We will document_store the result in this dictionary + namespace = {} + exec(code, namespace) + return namespace.get("result", ""), "" @classmethod async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]: @@ -145,18 +142,17 @@ async def run(self, *args, **kwargs) -> RunCodeResult: rsp = await self._aask(prompt) return RunCodeResult(summary=rsp, stdout=outs, stderr=errs) + @staticmethod + @handle_exception(exception_type=subprocess.CalledProcessError) + def _install_via_subprocess(cmd, check, cwd, env): + return subprocess.run(cmd, check=check, cwd=cwd, env=env) + @staticmethod def _install_dependencies(working_directory, env): install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"] logger.info(" ".join(install_command)) - try: - subprocess.run(install_command, check=True, cwd=working_directory, env=env) - except subprocess.CalledProcessError as e: - logger.warning(f"{e}") + RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env) install_pytest_command = ["python", "-m", "pip", "install", "pytest"] logger.info(" ".join(install_pytest_command)) - try: - subprocess.run(install_pytest_command, check=True, cwd=working_directory, env=env) - except subprocess.CalledProcessError as e: - logger.warning(f"{e}") + RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 5e4cdaea03..a1d81bc653 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -130,8 +130,7 @@ async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYS system_prompt = [system_text] prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( - # PREFIX = 
self.prefix, - ROLE=self.profile, + ROLE=self.prefix, CONTEXT=rsp, QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]), QUERY=str(context[-1]), diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 4b3e9aece6..365c87063f 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -154,11 +154,15 @@ async def run(self, *args, **kwargs) -> CodingContext: code=iterative_code, filename=self.context.code_doc.filename, ) - cr_prompt = EXAMPLE_AND_INSTRUCTION.format(format_example=format_example, ) + cr_prompt = EXAMPLE_AND_INSTRUCTION.format( + format_example=format_example, + ) logger.info( f"Code review and rewrite {self.context.code_doc.filename}: {i+1}/{k} | {len(iterative_code)=}, {len(self.context.code_doc.content)=}" ) - result, rewrited_code = await self.write_code_review_and_rewrite(context_prompt, cr_prompt, self.context.code_doc.filename) + result, rewrited_code = await self.write_code_review_and_rewrite( + context_prompt, cr_prompt, self.context.code_doc.filename + ) if "LBTM" in result: iterative_code = rewrited_code elif "LGTM" in result: diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index bb0cf8fb91..7c160fa895 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -67,7 +67,7 @@ class WritePRD(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) - async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: + async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are # related to the PRD. If they are related, rewrite the PRD. docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) @@ -111,7 +111,7 @@ async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) # optimization in subsequent steps. 
         return ActionOutput(content=change_files.json(), instruct_content=change_files)

-    async def _run_new_requirement(self, requirements, format=CONFIG.prompt_format) -> ActionOutput:
+    async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
         # sas = SearchAndSummarize()
         # # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
         # rsp = ""
@@ -121,7 +121,7 @@ async def _run_new_requirement(self, requirements, format=CONFIG.prompt_format)
         # logger.info(rsp)
         project_name = CONFIG.project_name if CONFIG.project_name else ""
         context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
-        node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=format)
+        node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema)
         await self._rename_workspace(node)
         return node

@@ -130,11 +130,11 @@ async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
         node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
         return node.get("is_relative") == "YES"

-    async def _merge(self, new_requirement_doc, prd_doc, format=CONFIG.prompt_format) -> Document:
+    async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
         if not CONFIG.project_name:
             CONFIG.project_name = Path(CONFIG.project_path).name
         prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
-        node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=format)
+        node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
         prd_doc.content = node.instruct_content.json(ensure_ascii=False)
         await self._rename_workspace(node)
         return prd_doc
@@ -182,7 +182,7 @@ async def _rename_workspace(prd):
             return

         if not CONFIG.project_name:
-            if isinstance(prd, ActionOutput) or isinstance(prd, ActionNode):
+            if isinstance(prd, (ActionOutput, ActionNode)):
                 ws_name = prd.instruct_content.dict()["Project Name"]
             else:
                 ws_name = CodeParser.parse_str(block="Project Name", text=prd)
diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py
index d96c0aeac0..edd94a463d 100644
--- a/metagpt/actions/write_prd_an.py
+++ b/metagpt/actions/write_prd_an.py
@@ -47,7 +47,7 @@ USER_STORIES = ActionNode(
     key="User Stories",
     expected_type=list[str],
-    instruction="Provide up to five scenario-based user stories.",
+    instruction="Provide 3 to 5 scenario-based user stories.",
     example=[
         "As a user, I want to be able to choose difficulty levels",
         "As a player, I want to see my score after each game",
@@ -57,7 +57,7 @@ COMPETITIVE_ANALYSIS = ActionNode(
     key="Competitive Analysis",
     expected_type=list[str],
-    instruction="Provide analyses for up to seven competitive products.",
+    instruction="Provide analyses of 5 to 7 competitive products.",
     example=["Python Snake Game: Simple interface, lacks advanced features"],
 )

@@ -92,8 +92,8 @@ REQUIREMENT_POOL = ActionNode(
     key="Requirement Pool",
     expected_type=list[list[str]],
-    instruction="List down the requirements with their priority (P0, P1, P2).",
-    example=[["P0", "..."], ["P1", "..."]],
+    instruction="List down the top-5 requirements with their priority (P0, P1, P2).",
+    example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
 )

 UI_DESIGN_DRAFT = ActionNode(
diff --git a/metagpt/config.py b/metagpt/config.py
index d7f5c1249b..131854a56e 100644
--- a/metagpt/config.py
+++ b/metagpt/config.py
@@ -8,6 +8,7 @@
 """
 import os
 from copy import deepcopy
+from enum import Enum
 from pathlib import Path
 from typing import Any
@@ -31,6 +32,15 @@ def __init__(self, message="The required configuration is not set"):
         super().__init__(self.message)


+class LLMProviderEnum(Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    SPARK = "spark"
+    ZHIPUAI = "zhipuai"
+    FIREWORKS = "fireworks"
+    OPEN_LLM = "open_llm"
+
+
 class Config(metaclass=Singleton):
     """
     Regular usage method:
@@ -45,38 +55,55 @@ class Config(metaclass=Singleton):
     default_yaml_file = METAGPT_ROOT / "config/config.yaml"

     def __init__(self, yaml_file=default_yaml_file):
-        golbal_options = OPTIONS.get()
+        global_options = OPTIONS.get()
+        # cli paras
+        self.project_path = ""
+        self.project_name = ""
+        self.inc = False
+        self.reqa_file = ""
+        self.max_auto_summarize_code = 0
+
         self._init_with_config_files_and_env(yaml_file)
-        logger.debug("Config loading done.")
         self._update()
-        golbal_options.update(OPTIONS.get())
+        global_options.update(OPTIONS.get())
+        logger.debug("Config loading done.")
+
+    def get_default_llm_provider_enum(self) -> LLMProviderEnum:
+        for k, v in [
+            (self.openai_api_key, LLMProviderEnum.OPENAI),
+            (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
+            (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
+            (self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
+            (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),  # not an API key, but reuses the same validity check
+        ]:
+            if self._is_valid_llm_key(k):
+                if self.openai_api_model:
+                    logger.info(f"OpenAI API Model: {self.openai_api_model}")
+                return v
+        raise NotConfiguredException("You should configure an LLM provider first")
+
+    @staticmethod
+    def _is_valid_llm_key(k: str) -> bool:
+        return k and k != "YOUR_API_KEY"

     def _update(self):
         # logger.info("Config loading done.")
         self.global_proxy = self._get("GLOBAL_PROXY")
+
         self.openai_api_key = self._get("OPENAI_API_KEY")
-        self.anthropic_api_key = self._get("Anthropic_API_KEY")
+        self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
         self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
         self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
         self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
         self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
-        if (
-            (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key)
-            and (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key)
-            and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key)
-            and (not self.open_llm_api_base)
-            and (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key)
-        ):
-            raise NotConfiguredException(
-                "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first "
-                "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE"
-            )
+        _ = self.get_default_llm_provider_enum()
+
         self.openai_api_base = self._get("OPENAI_API_BASE")
         self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
         self.openai_api_type = self._get("OPENAI_API_TYPE")
         self.openai_api_version = self._get("OPENAI_API_VERSION")
         self.openai_api_rpm = self._get("RPM", 3)
-        self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4")
+        self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4-1106-preview")
         self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
         self.deployment_name = self._get("DEPLOYMENT_NAME")
         self.deployment_id = self._get("DEPLOYMENT_ID")
@@ -90,7 +117,7 @@ def _update(self):
         self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
         self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")

-        self.claude_api_key = self._get("Anthropic_API_KEY")
+        self.claude_api_key = self._get("ANTHROPIC_API_KEY")
         self.serpapi_api_key =
self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") self.google_api_key = self._get("GOOGLE_API_KEY") @@ -116,10 +143,23 @@ def _update(self): self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "") self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False) - self.prompt_format = self._get("PROMPT_FORMAT", "json") + self.prompt_schema = self._get("PROMPT_FORMAT", "json") self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): + """update config via cli""" + + # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. + if project_path: + inc = True + project_name = project_name or Path(project_path).name + self.project_path = project_path + self.project_name = project_name + self.inc = inc + self.reqa_file = reqa_file + self.max_auto_summarize_code = max_auto_summarize_code + def _ensure_workspace_exists(self): self.workspace_path.mkdir(parents=True, exist_ok=True) logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}") @@ -142,8 +182,8 @@ def _init_with_config_files_and_env(self, yaml_file): @staticmethod def _get(*args, **kwargs): - m = OPTIONS.get() - return m.get(*args, **kwargs) + i = OPTIONS.get() + return i.get(*args, **kwargs) def get(self, key, *args, **kwargs): """Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found""" @@ -156,8 +196,8 @@ def __setattr__(self, name: str, value: Any) -> None: OPTIONS.get()[name] = value def __getattr__(self, name: str) -> Any: - m = OPTIONS.get() - return m.get(name) + i = OPTIONS.get() + return i.get(name) def set_context(self, options: dict): """Update current config""" @@ -176,8 +216,8 @@ def options(self): def new_environ(self): """Return a new os.environ object""" env = os.environ.copy() - m = self.options - env.update({k: v for k, v in m.items() if isinstance(v, str)}) + i = self.options + env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env diff --git a/metagpt/inspect_module.py b/metagpt/inspect_module.py deleted file mode 100644 index 48ceffc571..0000000000 --- a/metagpt/inspect_module.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/28 14:54 -@Author : alexanderwu -@File : inspect_module.py -""" - -import inspect - -import metagpt # replace with your module - - -def print_classes_and_functions(module): - """FIXME: NOT WORK..""" - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj): - print(f"Class: {name}") - elif inspect.isfunction(obj): - print(f"Function: {name}") - else: - print(name) - - print(dir(module)) - - -if __name__ == "__main__": - print_classes_and_functions(metagpt) diff --git a/metagpt/llm.py b/metagpt/llm.py index 7c0ad7975d..8763642f0a 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,32 +6,14 @@ @File : llm.py """ -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.open_llm_api import OpenLLMGPTAPI -from metagpt.provider.openai_api import OpenAIGPTAPI -from metagpt.provider.spark_api import SparkAPI -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = 
HumanProvider() # Avoid pre-commit error -def LLM() -> BaseGPTAPI: - """initialize different LLM instance according to the key field existence""" - # TODO a little trick, can use registry to initialize LLM instance further - if CONFIG.openai_api_key: - llm = OpenAIGPTAPI() - elif CONFIG.spark_api_key: - llm = SparkAPI() - elif CONFIG.zhipuai_api_key: - llm = ZhiPuAIGPTAPI() - elif CONFIG.open_llm_api_base: - llm = OpenLLMGPTAPI() - elif CONFIG.fireworks_api_key: - llm = FireWorksGPTAPI() - else: - raise RuntimeError("You should config a LLM configuration first") - - return llm +def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI: + """get the default llm provider""" + return LLM_REGISTRY.get_provider(provider) diff --git a/metagpt/manager.py b/metagpt/manager.py deleted file mode 100644 index a063608bed..0000000000 --- a/metagpt/manager.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:42 -@Author : alexanderwu -@File : manager.py -""" -from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.schema import Message - - -class Manager: - def __init__(self, llm: LLM = LLM()): - self.llm = llm # Large Language Model - self.role_directions = { - "User": "Product Manager", - "Product Manager": "Architect", - "Architect": "Engineer", - "Engineer": "QA Engineer", - "QA Engineer": "Product Manager", - } - self.prompt_template = """ - Given the following message: - {message} - - And the current status of roles: - {roles} - - Which role should handle this message? - """ - - async def handle(self, message: Message, environment): - """ - 管理员处理信息,现在简单的将信息递交给下一个人 - The administrator processes the information, now simply passes the information on to the next person - :param message: - :param environment: - :return: - """ - # Get all roles from the environment - roles = environment.get_roles() - # logger.debug(f"{roles=}, {message=}") - - # Build a context for the LLM to understand the situation - # context = { - # "message": str(message), - # "roles": {role.name: role.get_info() for role in roles}, - # } - # Ask the LLM to decide which role should handle the message - # chosen_role_name = self.llm.ask(self.prompt_template.format(context)) - - # FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程 - # The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards - next_role_profile = self.role_directions[message.role] - # logger.debug(f"{next_role_profile}") - for _, role in roles.items(): - if next_role_profile == role.profile: - next_role = role - break - else: - logger.error(f"No available role can handle message: {message}.") - return - - # Find the chosen role and handle the message - return await next_role.handle(message) diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 22032a86e3..ab2214261f 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -19,7 +19,7 @@ class LongTermMemory(Memory): def __init__(self): self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() + super().__init__() self.rc = None # RoleContext self.msg_from_recover = False @@ -37,7 +37,7 @@ def recover_memory(self, role_id: str, rc: "RoleContext"): self.msg_from_recover = False def add(self, message: Message): - super(LongTermMemory, self).add(message) + super().add(message) for action in self.rc.watch: if message.cause_by == action and not self.msg_from_recover: # 
currently, only add role's watching messages to its memory_storage @@ -50,7 +50,7 @@ def find_news(self, observed: list[Message], k=0) -> list[Message]: 1. find the short-term memory(stm) news 2. furthermore, filter out similar messages based on ltm(long-term memory), get the final news """ - stm_news = super(LongTermMemory, self).find_news(observed, k=k) # shot-term memory news + stm_news = super().find_news(observed, k=k) # shot-term memory news if not self.memory_storage.is_initialized: # memory_storage hasn't initialized, use default `find_news` to get stm_news return stm_news @@ -64,9 +64,9 @@ def find_news(self, observed: list[Message], k=0) -> list[Message]: return ltm_news[-k:] def delete(self, message: Message): - super(LongTermMemory, self).delete(message) + super().delete(message) # TODO delete message in memory_storage def clear(self): - super(LongTermMemory, self).clear() + super().clear() self.memory_storage.clean() diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index a213f6d7a7..fafb335687 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -58,7 +58,7 @@ def _get_index_and_store_fname(self): return index_fpath, storage_fpath def persist(self): - super(MemoryStorage, self).persist() + super().persist() logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 03802a716e..f5b06c8559 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -14,7 +14,7 @@ class Claude2: def ask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", @@ -24,7 +24,7 @@ def ask(self, prompt): return res.completion async def aask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 47ac9cf616..a76151666e 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -4,10 +4,12 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter +@register_provider(LLMProviderEnum.FIREWORKS) class FireWorksGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_fireworks(CONFIG) diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py new file mode 100644 index 0000000000..2b3ef93a36 --- /dev/null +++ b/metagpt/provider/llm_provider_registry.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 17:26 +@Author : alexanderwu +@File : llm_provider_registry.py +""" +from metagpt.config import LLMProviderEnum + + +class LLMProviderRegistry: + def __init__(self): + self.providers = {} + + def register(self, key, provider_cls): + self.providers[key] = provider_cls + + def get_provider(self, enum: LLMProviderEnum): + """get provider instance according to the enum""" + return self.providers[enum]() + + +# Registry instance +LLM_REGISTRY = LLMProviderRegistry() + + +def register_provider(key): + """register provider to registry""" + + def decorator(cls): + 
LLM_REGISTRY.register(key, cls) + return cls + + return decorator diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index f421e30c8c..bada0e294c 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -4,8 +4,9 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter @@ -31,6 +32,7 @@ def update_cost(self, prompt_tokens, completion_tokens, model): CONFIG.total_cost = self.total_cost +@register_provider(LLMProviderEnum.OPEN_LLM) class OpenLLMGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_openllm(CONFIG) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index a73bb0aa0c..0be70b3cab 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -18,10 +18,11 @@ wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE +from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( @@ -137,6 +138,7 @@ def log_and_reraise(retry_state): raise retry_state.outcome.exception() +@register_provider(LLMProviderEnum.OPENAI) class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ Check https://platform.openai.com/examples for examples @@ -329,7 +331,8 @@ def _calc_usage(self, messages: list[dict], rsp: str) -> dict: usage["completion_tokens"] = completion_tokens return usage except Exception as e: - logger.error("usage calculation failed!", e) + logger.error(f"{self.model} usage calculation failed!", e) + return {} else: return usage @@ -360,7 +363,7 @@ async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]: return results def _update_costs(self, usage: dict): - if CONFIG.calc_usage: + if CONFIG.calc_usage and usage: try: prompt_tokens = int(usage["prompt_tokens"]) completion_tokens = int(usage["completion_tokens"]) diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py index 0d1cfbb110..721476507d 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -44,7 +44,7 @@ def run_extract_content_from_output(self, content: str, right_key: str) -> str: def run_retry_parse_json_text(self, content: str) -> Union[dict, list]: """inherited class can re-implement the function""" - logger.info(f"extracted json CONTENT from output:\n{content}") + logger.debug(f"extracted json CONTENT from output:\n{content}") parsed_data = retry_parse_json_text(output=content) # should use output=content return parsed_data diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 60c86f4dcb..484fa7956a 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -19,11 +19,13 @@ import websocket # 使用websocket_client -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import 
register_provider +@register_provider(LLMProviderEnum.SPARK) class SparkAPI(BaseGPTAPI): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 92119b764d..eef0e51e16 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -16,9 +16,10 @@ wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI @@ -30,6 +31,7 @@ class ZhiPuEvent(Enum): FINISH = "finish" +@register_provider(LLMProviderEnum.ZHIPUAI) class ZhiPuAIGPTAPI(BaseGPTAPI): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index b84dbab9a2..3524a5bce3 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -15,17 +15,17 @@ from metagpt.config import CONFIG from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception class RepoParser(BaseModel): base_directory: Path = Field(default=None) - def parse_file(self, file_path): + @classmethod + @handle_exception(exception_type=Exception, default_return=[]) + def _parse_file(cls, file_path: Path) -> list: """Parse a Python file in the repository.""" - try: - return ast.parse(file_path.read_text()).body - except: - return [] + return ast.parse(file_path.read_text()).body def extract_class_and_function_info(self, tree, file_path): """Extract class, function, and global variable information from the AST.""" @@ -52,7 +52,7 @@ def generate_symbols(self): files_classes = [] directory = self.base_directory for path in directory.rglob("*.py"): - tree = self.parse_file(path) + tree = self._parse_file(path) file_info = self.extract_class_and_function_info(tree, path) files_classes.append(file_info) @@ -90,5 +90,10 @@ def main(): logger.info(pformat(symbols)) +def error(): + """raise Exception and logs it""" + RepoParser._parse_file(Path("test.py")) + + if __name__ == "__main__": main() diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index fa91d393d4..fce6c34256 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -28,7 +28,7 @@ def __init__( profile: str = "Architect", goal: str = "design a concise, usable, complete software system", constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries." - "Use same language as user requirement" + "Use same language as user requirement", ) -> None: """Initializes the Architect with given attributes.""" super().__init__(name, profile, goal, constraints) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index f1e65b177f..2620fe4e53 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -73,7 +73,7 @@ def __init__( profile: str = "Engineer", goal: str = "write elegant, readable, extensible, efficient code", constraints: str = "the code should conform to standards like google-style and be modular and maintainable. 
" - "Use same language as user requirement", + "Use same language as user requirement", n_borg: int = 1, use_code_review: bool = False, ) -> None: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index e5e9f2b5e7..7858d2caa7 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -54,4 +54,4 @@ async def _think(self) -> None: return self._rc.todo async def _observe(self, ignore_memory=False) -> int: - return await super(ProductManager, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 5a2b9be500..6577375132 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -26,7 +26,7 @@ def __init__( name: str = "Eve", profile: str = "Project Manager", goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " - "dependencies to start with the prerequisite modules", + "dependencies to start with the prerequisite modules", constraints: str = "use same language as user requirement", ) -> None: """ diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4439b9b19c..71b474a3be 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -178,4 +178,4 @@ async def _act(self) -> Message: async def _observe(self, ignore_memory=False) -> int: # This role has events that trigger and execute themselves based on conditions, and cannot rely on the # content of memory to activate. - return await super(QaEngineer, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1e7ebf711d..bf37a6637a 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -25,9 +25,8 @@ from pydantic import BaseModel, Field -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.actions.action_node import ActionNode -from metagpt.actions.add_requirement import UserRequirement from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory @@ -127,17 +126,7 @@ def history(self) -> list[Message]: return self.memory.get() -class _RoleInjector(type): - def __call__(cls, *args, **kwargs): - instance = super().__call__(*args, **kwargs) - - if not instance._rc.watch: - instance._watch([UserRequirement]) - - return instance - - -class Role(metaclass=_RoleInjector): +class Role: """Role/Agent""" def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False): @@ -149,16 +138,15 @@ def __init__(self, name="", profile="", goal="", constraints="", desc="", is_hum self._states = [] self._actions = [] self._role_id = str(self._setting) - self._rc = RoleContext() + self._rc = RoleContext(watch={any_to_str(UserRequirement)}) self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} - def _reset(self): self._states = [] self._actions = [] def _init_action_system_message(self, action: Action): - action.set_prefix(self._get_prefix(), self.profile) + action.set_prefix(self._get_prefix()) def _init_actions(self, actions): self._reset() @@ -203,8 +191,7 @@ def _watch(self, actions: Iterable[Type[Action]]): """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message buffer during _observe. 
""" - tags = {any_to_str(t) for t in actions} - self._rc.watch.update(tags) + self._rc.watch = {any_to_str(t) for t in actions} # check RoleContext after adding watch actions self._rc.check(self._role_id) @@ -280,7 +267,7 @@ async def _think(self) -> None: async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, @@ -401,6 +388,8 @@ async def run(self, with_message=None): msg = with_message elif isinstance(with_message, list): msg = Message("\n".join(with_message)) + if not msg.cause_by: + msg.cause_by = UserRequirement self.put_message(msg) if not await self._observe(): diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 5760202ff7..31de8e8968 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -59,7 +59,7 @@ async def _act_sp(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.memory.get(k=0)) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, diff --git a/metagpt/schema.py b/metagpt/schema.py index 5aec378e4e..d2f8d33e6d 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -18,10 +18,11 @@ import json import os.path import uuid +from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Dict, List, Optional, Set, TypedDict +from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar from pydantic import BaseModel, Field @@ -36,6 +37,7 @@ ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.exceptions import handle_exception class RawMessage(TypedDict): @@ -121,10 +123,6 @@ def __init__( :param send_to: Specifies the target recipient or consumer for message delivery in the environment. :param role: Message meta info tells who sent this message. 
""" - if not cause_by: - from metagpt.actions import UserRequirement - cause_by = UserRequirement - super().__init__( id=uuid.uuid4().hex, content=content, @@ -164,14 +162,11 @@ def dump(self) -> str: return self.json(exclude_none=True) @staticmethod + @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" - try: - d = json.loads(val) - return Message(**d) - except JSONDecodeError as err: - logger.error(f"parse json failed: {val}, error:{err}") - return None + i = json.loads(val) + return Message(**i) class UserMessage(Message): @@ -253,50 +248,46 @@ async def dump(self) -> str: return json.dumps(lst) @staticmethod - def load(self, v) -> "MessageQueue": + def load(data) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" - q = MessageQueue() + queue = MessageQueue() try: - lst = json.loads(v) + lst = json.loads(data) for i in lst: msg = Message(**i) - q.push(msg) + queue.push(msg) except JSONDecodeError as e: - logger.warning(f"JSON load failed: {v}, error:{e}") + logger.warning(f"JSON load failed: {data}, error:{e}") + + return queue + + +# 定义一个泛型类型变量 +T = TypeVar("T", bound="BaseModel") + - return q +class BaseContext(BaseModel, ABC): + @classmethod + @handle_exception + def loads(cls: Type[T], val: str) -> Optional[T]: + i = json.loads(val) + return cls(**i) -class CodingContext(BaseModel): +class CodingContext(BaseContext): filename: str design_doc: Optional[Document] task_doc: Optional[Document] code_doc: Optional[Document] - @staticmethod - def loads(val: str) -> CodingContext | None: - try: - m = json.loads(val) - return CodingContext(**m) - except Exception: - return None - -class TestingContext(BaseModel): +class TestingContext(BaseContext): filename: str code_doc: Document test_doc: Optional[Document] - @staticmethod - def loads(val: str) -> TestingContext | None: - try: - m = json.loads(val) - return TestingContext(**m) - except Exception: - return None - -class RunCodeContext(BaseModel): +class RunCodeContext(BaseContext): mode: str = "script" code: Optional[str] code_filename: str = "" @@ -308,28 +299,12 @@ class RunCodeContext(BaseModel): output_filename: Optional[str] output: Optional[str] - @staticmethod - def loads(val: str) -> RunCodeContext | None: - try: - m = json.loads(val) - return RunCodeContext(**m) - except Exception: - return None - -class RunCodeResult(BaseModel): +class RunCodeResult(BaseContext): summary: str stdout: str stderr: str - @staticmethod - def loads(val: str) -> RunCodeResult | None: - try: - m = json.loads(val) - return RunCodeResult(**m) - except Exception: - return None - class CodeSummarizeContext(BaseModel): design_filename: str = "" @@ -353,5 +328,5 @@ def __hash__(self): return hash((self.design_filename, self.task_filename)) -class BugFixContext(BaseModel): +class BugFixContext(BaseContext): filename: str = "" diff --git a/metagpt/startup.py b/metagpt/startup.py index f930c386b7..a1af90ffcf 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -1,13 +1,12 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import asyncio -from pathlib import Path import typer from metagpt.config import CONFIG -app = typer.Typer() +app = typer.Typer(add_completion=False) @app.command() @@ -22,12 +21,15 @@ def startup( inc: bool = typer.Option(default=False, help="Incremental mode. 
Use it to work with an existing repo."), project_path: str = typer.Option( default="", - help="Specify the directory path of the old version project to fulfill the " "incremental requirements.", + help="Specify the directory path of the old version project to fulfill the incremental requirements.", + ), + reqa_file: str = typer.Option( + default="", help="Specify the source file name for rewriting the quality assurance code." + ), - reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), max_auto_summarize_code: int = typer.Option( - default=-1, - help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. This parameter is used for debugging the workflow.", + default=0, + help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " + "unlimited. This parameter is used for debugging the workflow.", ), ): """Run a startup. Be a boss.""" @@ -40,15 +42,7 @@ def startup( ) from metagpt.team import Team - # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. - CONFIG.project_path = project_path - if project_path: - inc = True - project_name = project_name or Path(project_path).name - CONFIG.project_name = project_name - CONFIG.inc = inc - CONFIG.reqa_file = reqa_file - CONFIG.max_auto_summarize_code = max_auto_summarize_code + CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) company = Team() company.hire( diff --git a/metagpt/team.py b/metagpt/team.py index a5c405f803..570be4e6b1 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -3,10 +3,11 @@ """ @Time : 2023/5/12 00:30 @Author : alexanderwu -@File : software_company.py +@File : team.py @Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in Section 2.2.3.3 of RFC 135. """ +import warnings from pydantic import BaseModel, Field from metagpt.actions import UserRequirement @@ -21,8 +22,8 @@ class Team(BaseModel): """ - Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a platform for instant messaging, - dedicated to perform any multi-agent activity, such as collaboratively writing executable code. + Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and an env for instant messaging, + dedicated to performing any multi-agent activity, such as collaboratively writing executable code. """ env: Environment = Field(default_factory=Environment) @@ -47,7 +48,7 @@ def _check_balance(self): raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}") def run_project(self, idea, send_to: str = ""): - """Start a project from publishing user requirement.""" + """Run a project by publishing the user requirement.""" self.idea = idea # Human requirement. @@ -55,6 +56,16 @@ Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL) ) + def start_project(self, idea, send_to: str = ""): + """ + Deprecated: This method will be removed in the future. + Please use the `run_project` method instead. + """ + warnings.warn("The 'start_project' method is deprecated and will be removed in the future. 
" + "Please use the 'run_project' method instead.", + DeprecationWarning, stacklevel=2) + return self.run_project(idea=idea, send_to=send_to) + def _save(self): logger.info(self.json(ensure_ascii=False)) diff --git a/metagpt/actions/azure_tts.py b/metagpt/tools/azure_tts.py similarity index 65% rename from metagpt/actions/azure_tts.py rename to metagpt/tools/azure_tts.py index daa3f6892f..e59d98016c 100644 --- a/metagpt/actions/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -7,19 +7,16 @@ """ from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer -from metagpt.actions.action import Action -from metagpt.config import Config +from metagpt.config import CONFIG -class AzureTTS(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.config = Config() +class AzureTTS: + """https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles""" - # Parameters reference: https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles - def synthesize_speech(self, lang, voice, role, text, output_file): - subscription_key = self.config.get("AZURE_TTS_SUBSCRIPTION_KEY") - region = self.config.get("AZURE_TTS_REGION") + @classmethod + def synthesize_speech(cls, lang, voice, role, text, output_file): + subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY") + region = CONFIG.get("AZURE_TTS_REGION") speech_config = SpeechConfig(subscription=subscription_key, region=region) speech_config.speech_synthesis_voice_name = voice @@ -41,5 +38,5 @@ def synthesize_speech(self, lang, voice, role, text, output_file): if __name__ == "__main__": - azure_tts = AzureTTS("azure_tts") + azure_tts = AzureTTS() azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav") diff --git a/metagpt/tools/search_engine_meilisearch.py b/metagpt/tools/search_engine_meilisearch.py index f7c1c685a2..ea6db4dbd0 100644 --- a/metagpt/tools/search_engine_meilisearch.py +++ b/metagpt/tools/search_engine_meilisearch.py @@ -11,6 +11,8 @@ import meilisearch from meilisearch.index import Index +from metagpt.utils.exceptions import handle_exception + class DataSource: def __init__(self, name: str, url: str): @@ -34,11 +36,7 @@ def add_documents(self, data_source: DataSource, documents: List[dict]): index.add_documents(documents) self.set_index(index) + @handle_exception(exception_type=Exception, default_return=[]) def search(self, query): - try: - search_results = self._index.search(query) - return search_results["hits"] - except Exception as e: - # Handle MeiliSearch API errors - print(f"MeiliSearch API error: {e}") - return [] + search_results = self._index.search(query) + return search_results["hits"] diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a9bdd6e2dc..fa18694e3a 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -17,10 +17,16 @@ import os import platform import re +import typing from typing import List, Tuple, Union +import aiofiles +import loguru +from tenacity import RetryCallState, _utils + from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception def check_cmd_exists(command) -> int: @@ -191,7 +197,7 @@ def extract_struct(cls, text: str, data_type: Union[type(list), type(dict)]) -> result = ast.literal_eval(structure_text) # Ensure the result matches the specified data type - if 
isinstance(result, list) or isinstance(result, dict): + if isinstance(result, (list, dict)): return result raise ValueError(f"The extracted structure is not a {data_type}.") @@ -291,9 +297,6 @@ def __str__(self): def print_members(module, indent=0): """ https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python - :param module: - :param indent: - :return: """ prefix = " " * indent for name, obj in inspect.getmembers(module): @@ -311,6 +314,7 @@ def print_members(module, indent=0): def parse_recipient(text): + # FIXME: use ActionNode instead. pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) if recipient: @@ -327,18 +331,12 @@ def get_class_name(cls) -> str: return f"{cls.__module__}.{cls.__name__}" -def get_object_name(obj) -> str: - """Return class name of the object""" - cls = type(obj) - return f"{cls.__module__}.{cls.__name__}" - - -def any_to_str(val) -> str: +def any_to_str(val: str | typing.Callable) -> str: """Return the class name or the class name of the object, or 'val' if it's a string type.""" if isinstance(val, str): return val if not callable(val): - return get_object_name(val) + return get_class_name(type(val)) return get_class_name(val) @@ -346,20 +344,68 @@ def any_to_str(val) -> str: def any_to_str_set(val) -> set: """Convert any type to string set.""" res = set() - if isinstance(val, dict) or isinstance(val, list) or isinstance(val, set) or isinstance(val, tuple): + + # Check if the value is iterable, but not a string (since strings are technically iterable) + if isinstance(val, (dict, list, set, tuple)): + # Special handling for dictionaries to iterate over values + if isinstance(val, dict): + val = val.values() + for i in val: res.add(any_to_str(i)) else: res.add(any_to_str(val)) + return res -def is_subscribed(message, tags): +def is_subscribed(message: "Message", tags: set): """Return whether a consumer watching `tags` should receive the message.""" if MESSAGE_ROUTE_TO_ALL in message.send_to: return True - for t in tags: - if t in message.send_to: + for i in tags: + if i in message.send_to: return True return False + + +def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: + """ + Generates a logging function to be used after a call is retried. + + This generated function logs an error message with the outcome of the retried function call. It includes + the name of the function, the time taken for the call in seconds (formatted according to `sec_format`), + the number of attempts made, and the exception raised, if any. + + :param i: A Logger instance from the loguru library used to log the error message. + :param sec_format: A string format specifier for how to format the number of seconds since the start of the call. + Defaults to three decimal places. + :return: A callable that accepts a RetryCallState object and returns None. This callable logs the details + of the retried call. + """ + + def log_it(retry_state: "RetryCallState") -> None: + # If the function name is not known, default to "" + if retry_state.fn is None: + fn_name = "" + else: + # Retrieve the callable's name using a utility function + fn_name = _utils.get_callback_name(retry_state.fn) + + # Log an error message with the function name, time since start, attempt number, and the exception + i.error( + f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), " + f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. 
" + f"exp: {retry_state.outcome.exception()}" + ) + + return log_it + + +@handle_exception +async def aread(file_path: str) -> str: + """Read file asynchronously.""" + async with aiofiles.open(str(file_path), mode="r") as reader: + content = await reader.read() + return content diff --git a/metagpt/utils/custom_decoder.py b/metagpt/utils/custom_decoder.py index 373d16356f..eb01a1115c 100644 --- a/metagpt/utils/custom_decoder.py +++ b/metagpt/utils/custom_decoder.py @@ -25,7 +25,7 @@ def _scan_once(string, idx): except IndexError: raise StopIteration(idx) from None - if nextchar == '"' or nextchar == "'": + if nextchar in ("'", '"'): if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar: # Handle the case where the next two characters are the same as nextchar return parse_string(string, idx + 3, strict, delimiter=nextchar * 3) # triple quote diff --git a/metagpt/utils/dependency_file.py b/metagpt/utils/dependency_file.py index e8347d5679..8a6575e9e8 100644 --- a/metagpt/utils/dependency_file.py +++ b/metagpt/utils/dependency_file.py @@ -15,7 +15,8 @@ import aiofiles from metagpt.config import CONFIG -from metagpt.logs import logger +from metagpt.utils.common import aread +from metagpt.utils.exceptions import handle_exception class DependencyFile: @@ -36,21 +37,14 @@ async def load(self): """Load dependencies from the file asynchronously.""" if not self._filename.exists(): return - try: - async with aiofiles.open(str(self._filename), mode="r") as reader: - data = await reader.read() - self._dependencies = json.loads(data) - except Exception as e: - logger.error(f"Failed to load {str(self._filename)}, error:{e}") + self._dependencies = json.loads(await aread(self._filename)) + @handle_exception async def save(self): """Save dependencies to the file asynchronously.""" - try: - data = json.dumps(self._dependencies) - async with aiofiles.open(str(self._filename), mode="w") as writer: - await writer.write(data) - except Exception as e: - logger.error(f"Failed to save {str(self._filename)}, error:{e}") + data = json.dumps(self._dependencies) + async with aiofiles.open(str(self._filename), mode="w") as writer: + await writer.write(data) async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True): """Update dependencies for a file asynchronously. 
diff --git a/metagpt/utils/exceptions.py b/metagpt/utils/exceptions.py new file mode 100644 index 0000000000..b4b5aa590c --- /dev/null +++ b/metagpt/utils/exceptions.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 14:46 +@Author : alexanderwu +@File : exceptions.py +""" + + +import asyncio +import functools +import traceback +from typing import Any, Callable, Tuple, Type, TypeVar, Union + +from metagpt.logs import logger + +ReturnType = TypeVar("ReturnType") + + +def handle_exception( + _func: Callable[..., ReturnType] = None, + *, + exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + default_return: Any = None, +) -> Callable[..., ReturnType]: + """Handle exceptions raised by the decorated function and return a default value instead.""" + + def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + try: + return await func(*args, **kwargs) + except exception_type as e: + logger.opt(depth=1).error( + f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, " + f"stack: {traceback.format_exc()}" + ) + return default_return + + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + try: + return func(*args, **kwargs) + except exception_type as e: + logger.opt(depth=1).error( + f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, " + f"stack: {traceback.format_exc()}" + ) + return default_return + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + if _func is None: + return decorator + else: + return decorator(_func) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index 6bb9a1a973..f62b44eb86 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -11,6 +11,7 @@ import aiofiles from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception class File: @@ -19,6 +20,7 @@ class File: CHUNK_SIZE = 64 * 1024 @classmethod + @handle_exception async def write(cls, root_path: Path, filename: str, content: bytes) -> Path: """Write the file content to the specified local path. @@ -33,18 +35,15 @@ Raises: Exception: If an unexpected error occurs during the file writing process. """ - try: - root_path.mkdir(parents=True, exist_ok=True) - full_path = root_path / filename - async with aiofiles.open(full_path, mode="wb") as writer: - await writer.write(content) - logger.debug(f"Successfully write file: {full_path}") - return full_path - except Exception as e: - logger.error(f"Error writing file: {e}") - raise e + root_path.mkdir(parents=True, exist_ok=True) + full_path = root_path / filename + async with aiofiles.open(full_path, mode="wb") as writer: + await writer.write(content) + logger.debug(f"Successfully wrote file: {full_path}") + return full_path @classmethod + @handle_exception async def read(cls, file_path: Path, chunk_size: int = None) -> bytes: """Read the file content in chunks from the specified local path. @@ -58,18 +57,14 @@ Raises: Exception: If an unexpected error occurs during the file reading process. 
""" - try: - chunk_size = chunk_size or cls.CHUNK_SIZE - async with aiofiles.open(file_path, mode="rb") as reader: - chunks = list() - while True: - chunk = await reader.read(chunk_size) - if not chunk: - break - chunks.append(chunk) - content = b"".join(chunks) - logger.debug(f"Successfully read file, the path of file: {file_path}") - return content - except Exception as e: - logger.error(f"Error reading file: {e}") - raise e + chunk_size = chunk_size or cls.CHUNK_SIZE + async with aiofiles.open(file_path, mode="rb") as reader: + chunks = list() + while True: + chunk = await reader.read(chunk_size) + if not chunk: + break + chunks.append(chunk) + content = b"".join(chunks) + logger.debug(f"Successfully read file, the path of file: {file_path}") + return content diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index 2eca799a88..099556a6ba 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -19,6 +19,7 @@ from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import Document +from metagpt.utils.common import aread from metagpt.utils.json_to_markdown import json_to_markdown @@ -97,15 +98,7 @@ async def get(self, filename: Path | str) -> Document | None: path_name = self.workdir / filename if not path_name.exists(): return None - try: - async with aiofiles.open(str(path_name), mode="r") as reader: - doc.content = await reader.read() - except FileNotFoundError as e: - logger.info(f"open {str(path_name)} failed:{e}") - return None - except Exception as e: - logger.info(f"open {str(path_name)} failed:{e}") - return None + doc.content = await aread(path_name) return doc async def get_all(self) -> List[Document]: diff --git a/metagpt/utils/get_template.py b/metagpt/utils/get_template.py index 86c1915f75..7e05e5d5e5 100644 --- a/metagpt/utils/get_template.py +++ b/metagpt/utils/get_template.py @@ -8,10 +8,10 @@ from metagpt.config import CONFIG -def get_template(templates, format=CONFIG.prompt_format): - selected_templates = templates.get(format) +def get_template(templates, schema=CONFIG.prompt_schema): + selected_templates = templates.get(schema) if selected_templates is None: - raise ValueError(f"Can't find {format} in passed in templates") + raise ValueError(f"Can't find {schema} in passed in templates") # Extract the selected templates prompt_template = selected_templates["PROMPT_TEMPLATE"] diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb0..5e52846e18 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -6,7 +6,7 @@ import pickle from typing import Dict, List -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.schema import Message @@ -60,7 +60,7 @@ def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 266a53268a..ebfb85de7d 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -56,6 +56,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): if model in { "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", + 
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-1106", "gpt-4-0314", "gpt-4-32k-0314", @@ -63,7 +64,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): "gpt-4-32k-0613", "gpt-4-1106-preview", }: - tokens_per_message = 3 + tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> tokens_per_name = 1 elif model == "gpt-3.5-turbo-0301": tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py deleted file mode 100644 index 5ceed65d9a..0000000000 --- a/metagpt/utils/utils.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : - -import typing - -from tenacity import _utils - - -def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: - def log_it(retry_state: "RetryCallState") -> None: - if retry_state.fn is None: - fn_name = "" - else: - fn_name = _utils.get_callback_name(retry_state.fn) - logger.error( - f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), " - f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. " - f"exp: {retry_state.outcome.exception()}" - ) - - return log_it diff --git a/requirements-ocr.txt b/requirements-ocr.txt deleted file mode 100644 index cf6103afc5..0000000000 --- a/requirements-ocr.txt +++ /dev/null @@ -1,4 +0,0 @@ -paddlepaddle==2.4.2 -paddleocr>=2.0.1 -tabulate==0.9.0 --r requirements.txt diff --git a/setup.py b/setup.py index 730fffd35a..8ef2a69469 100644 --- a/setup.py +++ b/setup.py @@ -30,15 +30,15 @@ def run(self): setup( name="metagpt", - version="0.5.0", - description="The Multi-Role Meta Programming Framework", + version="0.5.2", + description="The Multi-Agent Framework", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/geekan/MetaGPT", author="Alexander Wu", author_email="alexanderwu@deepwisdom.ai", license="MIT", - keywords="metagpt multi-role multi-agent programming gpt llm metaprogramming", + keywords="metagpt multi-agent multi-role programming gpt llm metaprogramming", packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]), python_requires=">=3.9", install_requires=requirements, @@ -48,6 +48,7 @@ def run(self): "search-google": ["google-api-python-client==2.94.0"], "search-ddg": ["duckduckgo-search==3.8.5"], "pyppeteer": ["pyppeteer>=1.0.2"], + "ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"], }, cmdclass={ "install_mermaid": InstallMermaidCLI, diff --git a/tests/metagpt/actions/test_action_output.py b/tests/metagpt/actions/test_action_output.py index ef8e239bde..f1765cb03e 100644 --- a/tests/metagpt/actions/test_action_output.py +++ b/tests/metagpt/actions/test_action_output.py @@ -7,7 +7,7 @@ """ from typing import List, Tuple -from metagpt.actions import ActionOutput +from metagpt.actions.action_node import ActionNode t_dict = { "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', @@ -37,12 +37,12 @@ def test_create_model_class(): - test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) + test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) assert test_class.__name__ == "test_class" def test_create_model_class_with_mapping(): - t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) + t = ActionNode.create_model_class("test_class_1", 
WRITE_TASKS_OUTPUT_MAPPING) t1 = t(**t_dict) value = t1.dict()["Task list"] assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] diff --git a/tests/metagpt/actions/test_azure_tts.py b/tests/metagpt/actions/test_azure_tts.py index bcafe10f52..9995e96912 100644 --- a/tests/metagpt/actions/test_azure_tts.py +++ b/tests/metagpt/actions/test_azure_tts.py @@ -5,11 +5,11 @@ @Author : alexanderwu @File : test_azure_tts.py """ -from metagpt.actions.azure_tts import AzureTTS +from metagpt.tools.azure_tts import AzureTTS def test_azure_tts(): - azure_tts = AzureTTS("azure_tts") + azure_tts = AzureTTS() azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav") # Requires SUBSCRIPTION_KEY to be configured before running diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index 891dca6cac..a178ec8402 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -3,21 +3,27 @@ """ @Time : 2023/9/13 00:26 @Author : fisherdeng -@File : test_detail_mining.py +@File : test_generate_questions.py """ import pytest -from metagpt.actions.detail_mining import DetailMining +from metagpt.actions.generate_questions import GenerateQuestions from metagpt.logs import logger +context = """ +## topic +如何做一个生日蛋糕 + +## record +我认为应该先准备好材料,然后再开始做蛋糕。 +""" + @pytest.mark.asyncio -async def test_detail_mining(): - topic = "如何做一个生日蛋糕" - record = "我认为应该先准备好材料,然后再开始做蛋糕。" - detail_mining = DetailMining("detail_mining") - rsp = await detail_mining.run(topic=topic, record=record) +async def test_generate_questions(): + detail_mining = GenerateQuestions() + rsp = await detail_mining.run(context) logger.info(f"{rsp.content=}") - assert "##OUTPUT" in rsp.content - assert "蛋糕" in rsp.content + assert "Questions" in rsp.content + assert "1." in rsp.content diff --git a/tests/metagpt/actions/test_prepare_interview.py b/tests/metagpt/actions/test_prepare_interview.py new file mode 100644 index 0000000000..7c32882e08 --- /dev/null +++ b/tests/metagpt/actions/test_prepare_interview.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/13 00:26 +@Author : fisherdeng +@File : test_prepare_interview.py +""" +import pytest + +from metagpt.actions.prepare_interview import PrepareInterview +from metagpt.logs import logger + + +@pytest.mark.asyncio +async def test_prepare_interview(): + action = PrepareInterview() + rsp = await action.run("I just graduated and hope to find a job as a Python engineer") + logger.info(f"{rsp.content=}") + + assert "Questions" in rsp.content + assert "1." 
in rsp.content diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index c67ca689f3..7b74eb512b 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -8,7 +8,7 @@ from typing import List from metagpt.actions import UserRequirement, WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message @@ -42,7 +42,7 @@ def test_idea_message(): def test_actionout_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" content = "The user has requested the creation of a command-line interface (CLI) snake game" diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py index 8ac799bf3a..0932efa1f0 100644 --- a/tests/metagpt/roles/ui_role.py +++ b/tests/metagpt/roles/ui_role.py @@ -10,6 +10,7 @@ from metagpt.actions import Action, ActionOutput, WritePRD # from metagpt.const import WORKSPACE_ROOT +from metagpt.actions.action_node import ActionNode from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.roles import Role @@ -17,44 +18,38 @@ from metagpt.tools.sd_engine import SDEngine PROMPT_TEMPLATE = """ -# Context {context} -## Format example -{format_example} ------ -Role: You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. -Requirements: Based on the context, fill in the following missing information, provide detailed HTML and CSS code -Attention: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the code and triple quote. - -## UI Design Description:Provide as Plain text, place the design objective here -## Selected Elements:Provide as Plain text, up to 5 specified elements, clear and simple -## HTML Layout:Provide as Plain text, use standard HTML code -## CSS Styles (styles.css):Provide as Plain text,use standard css code -## Anything UNCLEAR:Provide as Plain text. Try to clarify it. - +## Role +You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. """ -FORMAT_EXAMPLE = """ - -## UI Design Description -```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ``` - -## Selected Elements - -Game Grid: The game grid is a rectangular... - -Snake: The player controls a snake that moves across the grid... - -Food: Food items (often represented as small objects or differently colored blocks) - -Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score. - -Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game. - - -## HTML Layout - +UI_DESIGN_DESC = ActionNode( + key="UI Design Desc", + expected_type=str, + instruction="place the design objective here", + example="Snake games are classic and addictive games with simple yet engaging elements. 
Here are the main elements" + " commonly found in snake games", +) + +SELECTED_ELEMENTS = ActionNode( + key="Selected Elements", + expected_type=list[str], + instruction="up to 5 specified elements, clear and simple", + example=[ + "Game Grid: The game grid is a rectangular...", + "Snake: The player controls a snake that moves across the grid...", + "Food: Food items (often represented as small objects or differently colored blocks)", + "Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.", + "Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.", + ], +) + +HTML_LAYOUT = ActionNode( + key="HTML Layout", + expected_type=str, + instruction="use standard HTML code", + example=""" @@ -71,9 +66,14 @@ - -## CSS Styles (styles.css) -body { +""", +) + +CSS_STYLES = ActionNode( + key="CSS Styles", + expected_type=str, + instruction="use standard css code", + example="""body { display: flex; justify-content: center; align-items: center; @@ -121,19 +121,25 @@ color: #ff0000; display: none; } +""", +) -## Anything UNCLEAR -There are no unclear points. +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any aspects of the project that are unclear and try to clarify them.", + example="...", +) -""" +NODES = [ + UI_DESIGN_DESC, + SELECTED_ELEMENTS, + HTML_LAYOUT, + CSS_STYLES, + ANYTHING_UNCLEAR, +] -OUTPUT_MAPPING = { - "UI Design Description": (str, ...), - "Selected Elements": (str, ...), - "HTML Layout": (str, ...), - "CSS Styles (styles.css)": (str, ...), - "Anything UNCLEAR": (str, ...), -} +UI_DESIGN_NODE = ActionNode.from_children("UI_DESIGN", NODES) def load_engine(func): @@ -223,10 +229,8 @@ async def _save(self, css_content, html_content): css_file_path = save_dir / "ui_design.css" html_file_path = save_dir / "ui_design.html" - with open(css_file_path, "w") as css_file: - css_file.write(css_content) - with open(html_file_path, "w") as html_file: - html_file.write(html_content) + css_file_path.write_text(css_content) + html_file_path.write_text(html_content) async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: """Run the UI Design action.""" @@ -234,9 +238,9 @@ async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutpu context = requirements[-1].content ui_design_draft = self.parse_requirement(context=context) # todo: parse requirements str - prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE) + prompt = PROMPT_TEMPLATE.format(context=ui_design_draft) logger.info(prompt) - ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING) + ui_describe = await UI_DESIGN_NODE.fill(prompt) logger.info(ui_describe.content) logger.info(ui_describe.instruct_content) css = self.parse_css_code(context=ui_describe.content) diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 8fac2503cf..dbe45130da 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -14,11 +14,11 @@ import pytest from pydantic import BaseModel -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.environment import Environment from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import get_class_name +from metagpt.utils.common import 
any_to_str class MockAction(Action): @@ -60,7 +60,7 @@ class Input(BaseModel): name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc ) role.subscribe({seed.subscription}) - assert role._rc.watch == set({}) + assert role._rc.watch == {any_to_str(UserRequirement)} assert role.name == seed.name assert role.profile == seed.profile assert role._setting.goal == seed.goal @@ -88,13 +88,13 @@ class Input(BaseModel): @pytest.mark.asyncio async def test_msg_to(): m = Message(content="a", send_to=["a", MockRole, Message]) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message}) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} m = Message(content="a", send_to=("a", MockRole, Message)) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} if __name__ == "__main__": diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 51ebd5baa5..79421fab25 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -13,7 +13,7 @@ from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage -from metagpt.utils.common import get_class_name +from metagpt.utils.common import any_to_str @pytest.mark.asyncio @@ -54,9 +54,9 @@ def test_message(): m.cause_by = "Message" assert m.cause_by == "Message" m.cause_by = Action - assert m.cause_by == get_class_name(Action) + assert m.cause_by == any_to_str(Action) m.cause_by = Action() - assert m.cause_by == get_class_name(Action) + assert m.cause_by == any_to_str(Action) m.content = "b" assert m.content == "b" @@ -67,7 +67,7 @@ def test_routes(): m.send_to = "b" assert m.send_to == {"b"} m.send_to = {"e", Action} - assert m.send_to == {"e", get_class_name(Action)} + assert m.send_to == {"e", any_to_str(Action)} if __name__ == "__main__": diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index ffa34866cf..f027d53f86 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -7,7 +7,7 @@ from typing import List, Tuple from metagpt.actions import WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, @@ -54,7 +54,7 @@ def test_actionoutout_schema_to_mapping(): def test_serialize_and_deserialize_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) message = Message( content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD