Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


_import_structure = {
"agents": ["Agent", "HfAgent", "OpenAiAgent"],
"agents": ["Agent", "HfAgent", "OpenAiAgent", "AzureOpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

Expand All @@ -46,7 +46,7 @@
_import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
from .agents import Agent, HfAgent, OpenAiAgent
from .agents import Agent, HfAgent, OpenAiAgent, AzureOpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

try:
Expand Down
126 changes: 126 additions & 0 deletions src/transformers/tools/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,132 @@ def _completion_generate(self, prompts, stop):
)
return [answer["text"] for answer in result["choices"]]

class AzureOpenAiAgent(Agent):
"""
Agent that uses the Azure OpenAI API to generate code.

<Tip warning={true}>

The Azure OpenAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
`"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.

</Tip>

Args:
deployment_id (`str`, *optional*, defaults to `"text-davinci-003"`):
The Azure deployment_id of the OpenAI model to use.
api_key (`str`, *optional*):
The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
api_base(`str`, *optional*):
The base URL for the your Azure deployment like "https://<deployment>.openai.azure.com/".
If unset, will look for the environment variable `"OPENAI_API_BASE"`.
api_version(`str`, *optional*):
The version of the API to use.
If unset, will look for the environment variable `"OPENAI_API_VERSION"`.
is_chat_model (`bool`, *optional*):
Whether the model is a chat model or not.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.

Example:

```py
from transformers import AzureOpenAiAgent

agent = AzureOpenAiAgent(deployment_id="text-davinci-003", api_key=xxx, api_base="https://<deployment>.openai.azure.com/", api_version="2023-03-15-preview", is_chat_model=False)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""

def __init__(
self,
deployment_id="text-davinci-003",
api_key=None,
api_base=None,
api_version=None,
is_chat_model=False,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `AzureOpenAiAgent` requires `openai`: `pip install openai`.")

openai.api_type = "azure"

if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an openai key to use `AzureOpenAIAgent`. You can get one here: Get one here "
"https://azure.microsoft.com/en-us/products/cognitive-services/openai-service`."
"If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = xxx."
)
else:
openai.api_key = api_key
if api_base is None:
api_base = os.environ.get("OPENAI_API_BASE", None)
if api_base is None:
raise ValueError(
"You need an Azure OpenAI base url to use `AzureOpenAIAgent`."
"If you have one, set it in your env with `os.environ['OPENAI_API_BASE'] = https://<deployment>.openai.azure.com/"
)
else:
openai.api_base = api_base
if api_version is None:
api_version = os.environ.get("OPENAI_API_VERSION", None)
if api_version is None:
raise ValueError(
"You need an openai api version to use `AzureOpenAIAgent`."
"If you have one, set it in your env with `os.environ['OPENAI_API_VERSION'] = 2023-03-15-preview"
)
else:
openai.api_version = api_version

self.deployment_id = deployment_id
self.is_chat_model = is_chat_model
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)

def generate_many(self, prompts, stop):
if self.is_chat_model:
return [self._chat_generate(prompt, stop) for prompt in prompts]
else:
return self._completion_generate(prompts, stop)

def generate_one(self, prompt, stop):
if self.is_chat_model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]

def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create(
engine=self.deployment_id,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result["choices"][0]["message"]["content"]

def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
engine=self.deployment_id,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]


class HfAgent(Agent):
"""
Expand Down