Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions examples/run_storm_with_tavily.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
from argparse import ArgumentParser

from knowledge_storm import (
STORMWikiRunnerArguments,
STORMWikiRunner,
STORMWikiLMConfigs,
)
from knowledge_storm.lm import OpenAIModel
from knowledge_storm.rm import TavilySearchRM
from knowledge_storm.utils import load_api_key


def main(args):
    """Run the STORM Wiki pipeline with Tavily as the retrieval module.

    Args:
        args: Parsed command-line namespace. Reads ``output_dir``,
            ``max_conv_turn``, ``max_perspective``, ``search_top_k``,
            ``retrieve_top_k``, ``max_thread_num`` and the four ``do_*``
            stage switches.

    The topic is read interactively from standard input.
    """
    # Expected to make OPENAI_API_KEY / TAVILY_API_KEY from secrets.toml
    # available through the environment (see the os.getenv calls below).
    load_api_key(toml_file_path="secrets.toml")

    lm_configs = STORMWikiLMConfigs()
    openai_kwargs = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "temperature": 1.0,
        "top_p": 0.9,
    }

    # Every pipeline stage needs a language model. Without these the runner is
    # constructed with an empty LM configuration and cannot generate anything.
    # Cheaper model for conversation simulation / question asking, stronger
    # model for outline and article writing.
    gpt_35_model_name = "gpt-3.5-turbo"
    gpt_4_model_name = "gpt-4o"
    conv_simulator_lm = OpenAIModel(
        model=gpt_35_model_name, max_tokens=500, **openai_kwargs
    )
    question_asker_lm = OpenAIModel(
        model=gpt_35_model_name, max_tokens=500, **openai_kwargs
    )
    outline_gen_lm = OpenAIModel(
        model=gpt_4_model_name, max_tokens=400, **openai_kwargs
    )
    article_gen_lm = OpenAIModel(
        model=gpt_4_model_name, max_tokens=700, **openai_kwargs
    )
    article_polish_lm = OpenAIModel(
        model=gpt_4_model_name, max_tokens=4000, **openai_kwargs
    )
    lm_configs.set_conv_simulator_lm(conv_simulator_lm)
    lm_configs.set_question_asker_lm(question_asker_lm)
    lm_configs.set_outline_gen_lm(outline_gen_lm)
    lm_configs.set_article_gen_lm(article_gen_lm)
    lm_configs.set_article_polish_lm(article_polish_lm)

    # Initialize the engine runner arguments.
    # --retrieve-top-k controls per-section retrieval inside the runner, so it
    # belongs here rather than on the search backend.
    engine_args = STORMWikiRunnerArguments(
        output_dir=args.output_dir,
        max_conv_turn=args.max_conv_turn,
        max_perspective=args.max_perspective,
        search_top_k=args.search_top_k,
        retrieve_top_k=args.retrieve_top_k,
        max_thread_num=args.max_thread_num,
    )

    # STORM is a knowledge curation system which consumes valid information from the internet.
    # We use TavilySearchRM to retrieve information from the internet.
    # Please set TAVILY_API_KEY in your environment variables or secrets.toml
    # The number of results per search query is governed by --search-top-k.
    tavily_rm = TavilySearchRM(
        tavily_search_api_key=os.getenv("TAVILY_API_KEY"), k=args.search_top_k
    )

    runner = STORMWikiRunner(engine_args, lm_configs, tavily_rm)

    topic = input("Topic: ")
    runner.run(
        topic=topic,
        do_research=args.do_research,
        do_generate_outline=args.do_generate_outline,
        do_generate_article=args.do_generate_article,
        do_polish_article=args.do_polish_article,
    )
    runner.post_run()
    runner.summary()


if __name__ == "__main__":
    # Command-line entry point: declare the options, then hand the parsed
    # namespace to main(). Options are registered in the same order as they
    # appear in --help.
    cli = ArgumentParser()

    # global arguments
    cli.add_argument(
        "--output-dir",
        type=str,
        default="./results/tavily_storm",
        help="Directory to store the outputs.",
    )

    # Integer-valued tuning knobs: (flag, default, help text).
    int_options = (
        (
            "--max-conv-turn",
            3,
            "Maximum number of questions in conversational questioning.",
        ),
        (
            "--max-perspective",
            3,
            "Maximum number of perspectives to consider in perspective-guided questioning.",
        ),
        (
            "--search-top-k",
            3,
            "Top k search results to consider for each search query.",
        ),
        (
            "--retrieve-top-k",
            3,
            "Top k collected search results to retrieve for each section.",
        ),
        (
            "--max-thread-num",
            10,
            "Maximum number of threads to use.",
        ),
    )
    for flag, default_value, description in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=description)

    # Boolean switches selecting which pipeline stages to execute: (flag, help text).
    stage_switches = (
        (
            "--do-research",
            "If True, simulate conversation to research the topic; otherwise, load the results.",
        ),
        (
            "--do-generate-outline",
            "If True, generate an outline for the topic; otherwise, load the results.",
        ),
        (
            "--do-generate-article",
            "If True, generate an article for the topic; otherwise, load the results.",
        ),
        (
            "--do-polish-article",
            "If True, polish the article by adding a summarization section and (optionally) removing duplicate content.",
        ),
    )
    for flag, description in stage_switches:
        cli.add_argument(flag, action="store_true", help=description)

    main(cli.parse_args())
77 changes: 77 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,3 +1236,80 @@ def forward(
logging.error(f"Error occurs when searching query {query}: {e}")

return collected_results


class TavilySearchRM(dspy.Retrieve):
    """Retrieval module that fetches web search results from the Tavily API."""

    def __init__(
        self,
        tavily_search_api_key=None,
        k=3,
        is_valid_source: Callable = None,
        timeout: float = 30,
    ):
        """Set up the retriever.

        Args:
            tavily_search_api_key: Tavily API key. When falsy, the
                TAVILY_API_KEY environment variable is used instead.
            k: Maximum number of results requested per query.
            is_valid_source: Optional function that takes a URL and returns a
                boolean; results whose URL is rejected are dropped. When not
                given, every URL is accepted.
            timeout: Seconds to wait for each HTTP request before giving up.

        Raises:
            RuntimeError: If no API key is supplied and TAVILY_API_KEY is not set.
        """
        super().__init__(k=k)
        api_key = tavily_search_api_key or os.environ.get("TAVILY_API_KEY")
        if not api_key:
            raise RuntimeError(
                "You must supply tavily_search_api_key or set environment variable TAVILY_API_KEY"
            )
        self.tavily_search_api_key = api_key
        self.request_timeout = timeout
        self.usage = 0

        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    def get_usage_and_reset(self):
        """Return the number of queries issued since the last reset, then zero the counter."""
        usage = self.usage
        self.usage = 0

        return {"TavilySearchRM": usage}

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        exclude_urls: Union[List[str], None] = None,
    ):
        """Search with Tavily for self.k top passages for query or queries

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.

        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        # Built once so membership tests do not rescan the list for every result.
        excluded = set(exclude_urls) if exclude_urls else set()
        self.usage += len(queries)
        collected_results = []
        for query in queries:
            # Best effort: a failure on one query is logged and must not
            # prevent the remaining queries from being searched.
            try:
                response = requests.post(
                    "https://api.tavily.com/search",
                    json={
                        "api_key": self.tavily_search_api_key,
                        "query": query,
                        "max_results": self.k,
                        "include_raw_content": True,
                    },
                    # Without a timeout a stalled connection blocks the caller forever.
                    timeout=self.request_timeout,
                )
                # Surface HTTP errors (bad key, rate limit, ...) in the log
                # instead of silently yielding no results.
                response.raise_for_status()
                payload = response.json()

                if "results" not in payload:
                    continue

                for r in payload["results"]:
                    url = r.get("url")
                    # A result without a URL cannot be cited or filtered;
                    # skip it rather than abort the rest of this query.
                    if not url:
                        continue
                    if self.is_valid_source(url) and url not in excluded:
                        collected_results.append(
                            {
                                "description": r.get("content", ""),
                                "snippets": [r.get("content", "")],
                                "title": r.get("title", ""),
                                "url": url,
                            }
                        )
            except Exception as e:
                logging.error(f"Error occurs when searching query {query}: {e}")

        return collected_results
51 changes: 51 additions & 0 deletions knowledge_storm/test_tavily.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from knowledge_storm.rm import TavilySearchRM

class TestTavilySearchRM(unittest.TestCase):
    """Unit tests for TavilySearchRM; the HTTP layer is mocked throughout."""

    def setUp(self):
        self.api_key = "test_api_key"
        # Patch the environment instead of assigning to it, so the fake key
        # is removed again after each test and cannot leak into other tests.
        env_patcher = patch.dict(os.environ, {"TAVILY_API_KEY": self.api_key})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    @staticmethod
    def _mock_search_response():
        """Build a fake HTTP response carrying two Tavily search results."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "results": [
                {
                    "url": "http://example.com/1",
                    "title": "Example Title 1",
                    "content": "Example Content 1",
                    "raw_content": "Raw Content 1",
                },
                {
                    "url": "http://example.com/2",
                    "title": "Example Title 2",
                    "content": "Example Content 2",
                    "raw_content": "Raw Content 2",
                },
            ]
        }
        return mock_response

    def test_init(self):
        rm = TavilySearchRM(tavily_search_api_key=self.api_key, k=5)
        self.assertEqual(rm.tavily_search_api_key, self.api_key)
        self.assertEqual(rm.k, 5)

    def test_init_falls_back_to_environment(self):
        rm = TavilySearchRM(k=3)
        self.assertEqual(rm.tavily_search_api_key, self.api_key)

    def test_init_without_key_raises(self):
        # patch.dict restores the variable when the context exits.
        with patch.dict(os.environ):
            os.environ.pop("TAVILY_API_KEY", None)
            with self.assertRaises(RuntimeError):
                TavilySearchRM()

    @patch("requests.post")
    def test_forward(self, mock_post):
        mock_post.return_value = self._mock_search_response()

        rm = TavilySearchRM(tavily_search_api_key=self.api_key, k=2)
        results = rm.forward(query_or_queries="test query")

        self.assertEqual(len(results), 2)
        self.assertEqual(results[0]["url"], "http://example.com/1")
        self.assertEqual(results[0]["snippets"], ["Example Content 1"])
        self.assertEqual(results[0]["description"], "Example Content 1")

        # Verify usage update
        usage = rm.get_usage_and_reset()
        self.assertEqual(usage["TavilySearchRM"], 1)
        # The counter must be zero again after the reset.
        self.assertEqual(rm.get_usage_and_reset()["TavilySearchRM"], 0)

    @patch("requests.post")
    def test_forward_excludes_urls(self, mock_post):
        mock_post.return_value = self._mock_search_response()

        rm = TavilySearchRM(tavily_search_api_key=self.api_key, k=2)
        results = rm.forward(
            query_or_queries=["test query"],
            exclude_urls=["http://example.com/1"],
        )

        self.assertEqual([r["url"] for r in results], ["http://example.com/2"])

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()