[Hackathon 5th No.73] ToT #7660
Merged
Changes from 14 commits
Commits
54 commits
522a7af
Hackathon TASK73 ToT
ErnestinaQiu a68dacb
update readme tutorial
ErnestinaQiu c65870b
modify according to Lint
ErnestinaQiu 315cc3e
modify according to Lint
ErnestinaQiu b001719
Delete LICENSE
ErnestinaQiu 8932aca
Update LICENSE
ErnestinaQiu da288af
black format
ErnestinaQiu 916d70c
isort format
ErnestinaQiu e64c51e
Update search_crosswords-dfs.ipynb
ErnestinaQiu ef5cfa6
update files formats
ErnestinaQiu 96a6d35
Update LICENSE
ErnestinaQiu 6c95517
Update LICENSE
ErnestinaQiu 1728255
Update LICENSE
ErnestinaQiu bd35dde
Update LICENSE
ErnestinaQiu f5ff4df
delete test data
ErnestinaQiu 5621975
delete some unnecessary files
ErnestinaQiu e7d3ba6
add paddlenlp-llama2
ErnestinaQiu 84ee4d0
fix one bug
ErnestinaQiu effc87b
fix outputs bug
ErnestinaQiu c514ca9
delete meta/llama2
ErnestinaQiu 402fa97
modify according to comments
ErnestinaQiu ae8c242
change according to comments
ErnestinaQiu c7979ed
Delete .gitignore
ErnestinaQiu c8f79e2
Create .gitignore
ErnestinaQiu 994286c
Move directory
1f9499a
Add tree of thoughts scripts
065a1e9
add first dir
ErnestinaQiu 3fd243d
Merge branch 'develop' of https://github.com/ErnestinaQiu/tot into de…
ErnestinaQiu 24982e4
add note
ErnestinaQiu cebe49e
Update README.md
ErnestinaQiu 987117e
Update requirements.txt
ErnestinaQiu c74e478
Update demo.py
ErnestinaQiu 26179dc
Update .gitignore
ErnestinaQiu e1fdd67
Update run.py
ErnestinaQiu e793463
Update __init__.py
ErnestinaQiu 5e4dcf1
chat templates
ErnestinaQiu 671c6b8
add Ernie
ErnestinaQiu 1face8b
Update llama.py
ErnestinaQiu e9b2100
Update bfs.py
ErnestinaQiu 94d7a82
Update models.py
ErnestinaQiu 4472dd6
Update run.py
ErnestinaQiu 780ed41
format style
ErnestinaQiu 8e90744
format style
ErnestinaQiu 8fc80a3
format style
ErnestinaQiu e53ad12
format style
ErnestinaQiu c29e61c
format style
ErnestinaQiu 1e9c384
format style
ErnestinaQiu 94153a9
format style
ErnestinaQiu 0592931
format style
ErnestinaQiu b1a65a9
Remove the duplicated "test results" section
ErnestinaQiu fac1c02
Remove the Ernie token; read it from an environment variable instead
ErnestinaQiu fc122c3
format style
ErnestinaQiu df25595
format style
ErnestinaQiu 34c0953
Remove commented-out code
ErnestinaQiu
.gitignore
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| /*__pycache__/ | ||
| dist/ | ||
| src/tree_of_thoughts_llm.egg-info/ | ||
| .env | ||
| *.pyc | ||
| *.DS_Store | ||
| spark/test.py |
MANIFEST.in
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| include src/tot/data/24/24.csv | ||
| include src/tot/data/crosswords/mini0505_0_100_5.json | ||
| include src/tot/data/crosswords/mini0505.json | ||
| include src/tot/data/text/data_100_random_text.txt |
README.md
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # Tree of Thoughts (ToT) | ||
|
|
||
|  | ||
|
|
||
| Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, and model outputs. | ||
| Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) for a one-minute overview. | ||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
| ## Setup | ||
| 1. Set up OpenAI API key and store in environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)). | ||
|
|
||
| 2. Install the `tot` package in one of two ways: | ||
| - Option 1: Install from PyPI | ||
| ```bash | ||
| pip install tree-of-thoughts-llm | ||
| ``` | ||
| - Option 2: Install from source | ||
| ```bash | ||
| git clone https://github.com/PaddlePaddle/PaddleNLP.git | ||
| cd pipelines/examples/agents | ||
| pip install -r requirements.txt | ||
| pip install -e . # install `tot` package | ||
| ``` | ||
| 3. Install Meta's Llama 2 following the official Meta tutorial, then set the model path in `llm_config.yaml` (a sketch of the expected entry follows). | ||
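For reference, here is a minimal sketch of the model entry such a config might contain, inferred from how `llama2/__init__.py` looks up `ckpt_dir` and `tokenizer_path` for the chosen model. Note that the code loads `llm_config.yml` from the working directory; the file name and the paths below are assumptions.

```yaml
# Hypothetical llm_config.yml entry -- point both paths at your local Llama 2 download.
llama-2-7b-chat:
  ckpt_dir: /path/to/llama-2-7b-chat/        # directory containing the checkpoint files
  tokenizer_path: /path/to/tokenizer.model   # SentencePiece tokenizer shipped with Llama 2
```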
|
|
||
| ## Quick Start | ||
| The following minimal script will attempt to solve the game of 24 with `4 5 6 10` (might be a bit slow as it's using llama-2-7b-chat): | ||
|
|
||
|
|
||
| Run it from `pipelines/examples/agents/tree-of-thought-llm`: | ||
|
|
||
| ``` | ||
| python demo.py | ||
| ``` | ||
| The full code is as follows: | ||
|
|
||
| ```python | ||
| import argparse | ||
| from tot.methods.bfs import solve | ||
| from tot.tasks.game24 import Game24Task | ||
|
|
||
| args = argparse.Namespace(backend='llama-2-7b-chat', temperature=0.6, task='game24', naive_run=False, prompt_sample=None, method_generate='propose', method_evaluate='value', method_select='greedy', n_generate_sample=1, n_evaluate_sample=3, n_select_sample=5) | ||
|
|
||
|
|
||
| task = Game24Task() | ||
| ys, infos = solve(args, task, 900) | ||
| print(ys[0]) | ||
| ``` | ||
|
|
||
| And the output would be something like (note it's not deterministic, and sometimes the output can be wrong): | ||
| ``` | ||
| 10 - 4 = 6 (left: 5 6 6) | ||
| 5 * 6 = 30 (left: 6 30) | ||
| 30 - 6 = 24 (left: 24) | ||
| Answer: (5 * (10 - 4)) - 6 = 24 | ||
| ``` | ||
|
|
||
| ## Paper Experiments | ||
|
|
||
| Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``, except for crosswords, where ToT uses a DFS algorithm that can be run via ``scripts/crosswords/search_crosswords-dfs.ipynb``. | ||
|
|
||
| The very simple ``run.py`` implements the ToT + BFS algorithm, as well as the naive IO/CoT sampling. Some key arguments (an example invocation is sketched after this list): | ||
|
|
||
| - ``--naive_run``: if True, run naive IO/CoT sampling instead of ToT + BFS. | ||
| - ``--prompt_sample`` (choices=[``standard``, ``cot``]): sampling prompt | ||
| - ``--method_generate`` (choices=[``sample``, ``propose``]): thought generator, whether to sample independent thoughts (used in Creative Writing) or propose sequential thoughts (used in Game of 24) | ||
| - ``--method_evaluate`` (choices=[``value``, ``vote``]): state evaluator, whether to value states independently (used in Game of 24) or vote on states together (used in Creative Writing) | ||
| - ``--n_generate_sample``: number of times to prompt for thought generation | ||
| - ``--n_evaluate_sample``: number of times to prompt for state evaluation | ||
| - ``--n_select_sample``: number of states to keep from each step (i.e. ``b`` in the paper's ToT + BFS algorithm) | ||
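To illustrate how these flags combine, here is a hedged sketch of a ToT + BFS invocation for Game of 24. The flag names mirror the fields of the `argparse.Namespace` in the Quick Start above; the exact option names accepted by `run.py` are an assumption and may differ.

```bash
# Hypothetical run.py invocation; option names are assumed to match the
# Namespace fields shown in demo.py (verify against run.py's argparse setup).
python run.py \
    --backend llama-2-7b-chat \
    --task game24 \
    --method_generate propose \
    --method_evaluate value \
    --method_select greedy \
    --n_generate_sample 1 \
    --n_evaluate_sample 3 \
    --n_select_sample 5
```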
|
|
||
|
|
||
|
|
||
| ## Paper Trajectories | ||
| ``logs/`` contains all the trajectories from the paper's experiments, except for ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json`` which was reproduced after the paper (as the original experiment was done in a notebook) and achieved a 69\% score instead of the original 74\% score due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this shouldn't affect the main conclusions of the paper. | ||
|
|
||
| ## How to Add A New Task | ||
| Setting up a new task is easy, and mainly involves two steps (a minimal sketch follows the list). | ||
| * Set up a new task class in ``tot/tasks/`` and task files in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``. | ||
| * Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts. | ||
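As a rough illustration of the first step, a hypothetical task class is sketched below. The base-class location and the exact method names (`get_input`, `test_output`) are assumptions; copy the real interface from ``tot/tasks/game24.py``.

```python
# Hypothetical new task; mirror the actual interface of tot/tasks/game24.py.
from tot.tasks.base import Task  # assumed location of the Task base class


class MyTask(Task):
    def __init__(self, file="my_task_data.tsv"):
        super().__init__()
        # One "question<TAB>answer" pair per line; place the file under tot/data/.
        with open(file) as f:
            self.data = [line.split("\t") for line in f.read().splitlines()]

    def __len__(self):
        return len(self.data)

    def get_input(self, idx):
        # Raw problem text handed to the prompts for instance `idx`.
        question, _ = self.data[idx]
        return question

    def test_output(self, idx, output):
        # Placeholder scoring: 1 if the generated text mentions the reference answer.
        _, answer = self.data[idx]
        return {"r": int(answer in output)}
```

Remember to register the new class in ``tot/tasks/__init__.py`` and to add the corresponding prompts in ``tot/prompts/``, as described above.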
|
|
||
| If there are any questions, please contact ErnestinaQiu at [email protected]
demo.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| import argparse | ||
|
|
||
| from src.tot.methods.bfs import solve | ||
| from src.tot.tasks.game24 import Game24Task | ||
|
|
||
| args = argparse.Namespace( | ||
| backend="llama-2-7b-chat", | ||
| temperature=0.6, | ||
| task="game24", | ||
| naive_run=False, | ||
| prompt_sample=None, | ||
| method_generate="propose", | ||
| method_evaluate="value", | ||
| method_select="greedy", | ||
| n_generate_sample=1, | ||
| n_evaluate_sample=3, | ||
| n_select_sample=5, | ||
| ) | ||
|
|
||
| task = Game24Task() | ||
| ys, infos = solve(args, task, 900) | ||
| print(ys[0]) | ||
| print(infos) |
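For context, the third argument to `solve` is taken here to be the index of the puzzle instance (the README's `4 5 6 10` example sits at index 900); that interpretation is an assumption. A hedged sketch of looping the same setup over a few indices:

```python
# Hypothetical batch loop around the demo above; assumes solve(args, task, idx)
# receives the dataset index of the instance to attempt.
for idx in range(900, 905):
    ys, infos = solve(args, task, idx)
    print(f"puzzle {idx}: best candidate -> {ys[0]}")
```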
104 changes: 104 additions & 0 deletions
pipelines/examples/agents/tree-of-thought-llm/llama2/__init__.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| """ | ||
| author: Ernestina | ||
| description: 1) set configuration 2) initialize llama2 | ||
| """ | ||
| import os | ||
| import time | ||
| from typing import List, Optional | ||
|
|
||
| import yaml | ||
| from llama2.llama.llama import Dialog, Llama | ||
|
|
||
| os.environ["WORLD_SIZE"] = "1" | ||
| os.environ["RANK"] = "0" | ||
| os.environ["MASTER_ADDR"] = "localhost" | ||
| os.environ["MASTER_PORT"] = "8020" | ||
|
|
||
| llm_config_path = os.path.join(os.getcwd(), "llm_config.yml") | ||
| with open(llm_config_path, "r") as f: | ||
| log_config = yaml.full_load(f.read()) | ||
|
|
||
|
|
||
| class ChatCompletion: | ||
| global log_config | ||
| global max_seq_len | ||
| global max_batch_size | ||
|
|
||
| def __init__(self, model="llama-2-7b-chat") -> None: | ||
| ckpt_dir = log_config[model]["ckpt_dir"] | ||
| tokenizer_path = log_config[model]["tokenizer_path"] | ||
| # ckpt_dir = f"/mnt/e/study/dl/llama2/{model}/" | ||
| # tokenizer_path = "/mnt/e/study/dl/llama2/tokenizer.model" | ||
| max_seq_len = 1000 | ||
| max_batch_size = 6 | ||
| self.generator = Llama.build( | ||
| ckpt_dir=ckpt_dir, | ||
| tokenizer_path=tokenizer_path, | ||
| max_seq_len=max_seq_len, | ||
| max_batch_size=max_batch_size, | ||
| ) | ||
|
|
||
| # @staticmethod | ||
| def create( | ||
| self, | ||
| messages: List[Dialog], | ||
| temperature: float = 0.6, | ||
| top_p: float = 0.9, | ||
| max_gen_len: Optional[int] = None, | ||
| ): | ||
| """ | ||
| Entry point of the program for generating text using a pretrained model. | ||
|
|
||
| Args: | ||
| messages (list): There are two roles including "system" and "user". | ||
| --Example [[{"role": "user", "content": "what is the recipe of mayonnaise?"}, {"role": "system", "content": "Always answer with Haiku"}]] | ||
| ckpt_dir (str): The directory containing checkpoint files for the pretrained model. | ||
| tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding. | ||
| temperature (float, optional): The temperature value for controlling randomness in generation. | ||
| Defaults to 0.6. | ||
| top_p (float, optional): The top-p sampling parameter for controlling diversity in generation. | ||
| Defaults to 0.9. | ||
| max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512. | ||
| max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8. | ||
| max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be | ||
| set to the model's max sequence length. Defaults to None. | ||
| """ | ||
| results = self.generator.chat_completion( | ||
| messages, # type: ignore | ||
| max_gen_len=max_gen_len, | ||
| temperature=temperature, | ||
| top_p=top_p, | ||
| ) | ||
|
|
||
| completion = { | ||
| "choices": [], | ||
| "created": time.time(), | ||
| "id": "llama2_{}".format(int(time.time())), | ||
| "model": "llama-2-7b-chat", | ||
| "object": "chat.completion", | ||
| "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}, | ||
| } | ||
|
|
||
| assert len(messages) == len(results) | ||
| for i in range(len(results)): | ||
| dialog = messages[i] | ||
| print(f"dialog: \n {dialog}") | ||
| result = results[i] | ||
| if i == len(results) - 1: | ||
| finish_reason = "stop" | ||
| else: | ||
| finish_reason = "length" | ||
| tmp = { | ||
| "finish_reason": finish_reason, | ||
| "index": i, | ||
| "message": {"content": "", "role": ""}, | ||
| } | ||
| tmp["message"]["role"] = result["generation"]["role"] | ||
| tmp["message"]["content"] = result["generation"]["content"].replace( | ||
| "\n", "" | ||
| ) | ||
|
|
||
| completion["choices"].append(tmp) | ||
| print(f"\n result: \n {result}") | ||
|
|
||
| return completion |
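A minimal usage sketch for the wrapper above, assuming the checkpoints are configured in `llm_config.yml` and that the module is importable as `llama2` (the message format follows the docstring's example; the import path is an assumption):

```python
# Hypothetical usage of the ChatCompletion wrapper defined above.
from llama2 import ChatCompletion  # assumed import path for this package layout

chat = ChatCompletion(model="llama-2-7b-chat")  # builds the Llama 2 generator once
dialogs = [
    [
        {"role": "system", "content": "Always answer with Haiku"},
        {"role": "user", "content": "what is the recipe of mayonnaise?"},
    ]
]
completion = chat.create(messages=dialogs, temperature=0.6, top_p=0.9)
print(completion["choices"][0]["message"]["content"])
```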
160 changes: 160 additions & 0 deletions
pipelines/examples/agents/tree-of-thought-llm/llama2/llama/.gitignore
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
| *.py[cod] | ||
| *$py.class | ||
|
|
||
| # C extensions | ||
| *.so | ||
|
|
||
| # Distribution / packaging | ||
| .Python | ||
| build/ | ||
| develop-eggs/ | ||
| dist/ | ||
| downloads/ | ||
| eggs/ | ||
| .eggs/ | ||
| lib/ | ||
| lib64/ | ||
| parts/ | ||
| sdist/ | ||
| var/ | ||
| wheels/ | ||
| share/python-wheels/ | ||
| *.egg-info/ | ||
| .installed.cfg | ||
| *.egg | ||
| MANIFEST | ||
|
|
||
| # PyInstaller | ||
| # Usually these files are written by a python script from a template | ||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
| *.manifest | ||
| *.spec | ||
|
|
||
| # Installer logs | ||
| pip-log.txt | ||
| pip-delete-this-directory.txt | ||
|
|
||
| # Unit test / coverage reports | ||
| htmlcov/ | ||
| .tox/ | ||
| .nox/ | ||
| .coverage | ||
| .coverage.* | ||
| .cache | ||
| nosetests.xml | ||
| coverage.xml | ||
| *.cover | ||
| *.py,cover | ||
| .hypothesis/ | ||
| .pytest_cache/ | ||
| cover/ | ||
|
|
||
| # Translations | ||
| *.mo | ||
| *.pot | ||
|
|
||
| # Django stuff: | ||
| *.log | ||
| local_settings.py | ||
| db.sqlite3 | ||
| db.sqlite3-journal | ||
|
|
||
| # Flask stuff: | ||
| instance/ | ||
| .webassets-cache | ||
|
|
||
| # Scrapy stuff: | ||
| .scrapy | ||
|
|
||
| # Sphinx documentation | ||
| docs/_build/ | ||
|
|
||
| # PyBuilder | ||
| .pybuilder/ | ||
| target/ | ||
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
|
|
||
| # IPython | ||
| profile_default/ | ||
| ipython_config.py | ||
|
|
||
| # pyenv | ||
| # For a library or package, you might want to ignore these files since the code is | ||
| # intended to run in multiple environments; otherwise, check them in: | ||
| # .python-version | ||
|
|
||
| # pipenv | ||
| # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
| # However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
| # having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
| # install all needed dependencies. | ||
| #Pipfile.lock | ||
|
|
||
| # poetry | ||
| # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
| # This is especially recommended for binary packages to ensure reproducibility, and is more | ||
| # commonly ignored for libraries. | ||
| # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
| #poetry.lock | ||
|
|
||
| # pdm | ||
| # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
| #pdm.lock | ||
| # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
| # in version control. | ||
| # https://pdm.fming.dev/#use-with-ide | ||
| .pdm.toml | ||
|
|
||
| # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
| __pypackages__/ | ||
|
|
||
| # Celery stuff | ||
| celerybeat-schedule | ||
| celerybeat.pid | ||
|
|
||
| # SageMath parsed files | ||
| *.sage.py | ||
|
|
||
| # Environments | ||
| .env | ||
| .venv | ||
| env/ | ||
| venv/ | ||
| ENV/ | ||
| env.bak/ | ||
| venv.bak/ | ||
|
|
||
| # Spyder project settings | ||
| .spyderproject | ||
| .spyproject | ||
|
|
||
| # Rope project settings | ||
| .ropeproject | ||
|
|
||
| # mkdocs documentation | ||
| /site | ||
|
|
||
| # mypy | ||
| .mypy_cache/ | ||
| .dmypy.json | ||
| dmypy.json | ||
|
|
||
| # Pyre type checker | ||
| .pyre/ | ||
|
|
||
| # pytype static type analyzer | ||
| .pytype/ | ||
|
|
||
| # Cython debug symbols | ||
| cython_debug/ | ||
|
|
||
| # PyCharm | ||
| # JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
| # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
| # and can be added to the global gitignore or merged into this file. For a more nuclear | ||
| # option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
| #.idea/ | ||