
Commit 250943f

abrichr and Cody-DV authored

feat(VisualReplayStrategy): adapters: ultralytics, som, anthropic, google; remove_move_before_click; vision.py

* add prompts/, adapters/openai.py, strategies/visual.py (wip)
* adapters.anthropic
* add anthropic.py
* prompt with active segment descriptions
* Set-of-Mark Prompting Adapter (#612)
* Update openadapt/config.py
* remove_move_before_click
* started_counter; adapters.ultralytics
* add vision.py
* add openadapt/adapters/google.py
* filter_masks_by_size
* documentation
* update README
* add ultralytics
* exclude alembic in black/flake8
* exclude .venv in black/flake8
* disable som adapter; remove logging
* add adapters.google

Co-authored-by: Cody DeVilliers <[email protected]>

1 parent: cc645c4

32 files changed: +3,817 additions, −1,269 deletions

.flake8

Lines changed: 1 addition & 1 deletion

```
@@ -4,4 +4,4 @@ exclude =
     .venv
 docstring-convention = google
 max-line-length = 88
-extend-ignore = ANN101
+extend-ignore = ANN101, E203
```

(E203, "whitespace before ':'", conflicts with Black's slice formatting, so Black-formatted projects typically ignore it.)

.github/workflows/main.yml

Lines changed: 5 additions & 8 deletions

```
@@ -25,6 +25,9 @@ jobs:
     steps:
       - name: Checkout code
         uses: actions/checkout@v3
+        with:
+          ref: ${{ env.BRANCH }}
+          repository: ${{ env.REPO }}

       - name: Set up Python
         uses: actions/setup-python@v3
@@ -35,12 +38,6 @@ jobs:
         if: matrix.os == 'macos-latest'
         run: sh install/install_openadapt.sh

-      - name: Checkout code
-        uses: actions/checkout@v3
-        with:
-          ref: ${{ env.BRANCH }}
-          repository: ${{ env.REPO }}
-
       - name: Install poetry
         uses: snok/install-poetry@v1
         with:
@@ -63,7 +60,7 @@ jobs:
         if: steps.cache-deps.outputs.cache-hit == 'true'

       - name: Check formatting with Black
-        run: poetry run black --preview --check .
+        run: poetry run black --preview --check . --exclude '/(alembic|\.venv)/'

       - name: Run Flake8
-        run: poetry run flake8
+        run: poetry run flake8 --exclude=alembic,.venv
```

The change moves the checkout step ahead of Python setup and dependency installation, and the new Black and Flake8 excludes mirror the .flake8 update above, keeping auto-generated alembic migrations and the virtualenv out of lint checks.

README.md

Lines changed: 5 additions & 0 deletions

````
@@ -189,6 +189,11 @@ python -m openadapt.replay NaiveReplayStrategy
 Other replay strategies include:

 - [`StatefulReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/stateful.py): Proof-of-concept which uses the OpenAI GPT-4 API with prompts constructed via OS-level window data.
+- [`VisualReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/visual.py): Uses [Fast Segment Anything Model (FastSAM)](https://github.com/CASIA-IVA-Lab/FastSAM) to segment the active window. Accepts an "instructions" parameter that is used to modify the recording, e.g.:
+
+```
+python -m openadapt.replay VisualReplayStrategy --instructions "Multiply 9x5 instead of 6x8"
+```

 See https://github.com/OpenAdaptAI/OpenAdapt/tree/main/openadapt/strategies for a complete list. More ReplayStrategies coming soon! (see [Contributing](#Contributing)).
````

New file (Alembic migration, revision 30a5ba9d6453)

Lines changed: 34 additions & 0 deletions

```python
"""add active_segment_description and available_segment_descriptions

Revision ID: 30a5ba9d6453
Revises: 530f0663324e
Create Date: 2024-04-05 12:02:57.843244

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '30a5ba9d6453'
down_revision = '530f0663324e'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('action_event', schema=None) as batch_op:
        batch_op.add_column(sa.Column('active_segment_description', sa.String(), nullable=True))
        batch_op.add_column(sa.Column('available_segment_descriptions', sa.String(), nullable=True))

    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('action_event', schema=None) as batch_op:
        batch_op.drop_column('available_segment_descriptions')
        batch_op.drop_column('active_segment_description')

    # ### end Alembic commands ###
```
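This migration adds two nullable string columns to the action_event table. As a minimal sketch (assuming an alembic.ini at the repository root, which is an assumption here), it could be applied programmatically:

```python
# A minimal sketch, assuming alembic.ini lives at the repository root.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
# Runs upgrade() above, adding both segment-description columns.
command.upgrade(cfg, "30a5ba9d6453")
```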

openadapt/adapters/__init__.py

Lines changed: 40 additions & 0 deletions

```python
"""Adapters for completion and segmentation."""

from types import ModuleType

from openadapt import config
from . import anthropic
from . import openai
from . import replicate
from . import som
from . import ultralytics
from . import google


def get_default_prompt_adapter() -> ModuleType:
    """Returns the default prompt adapter module.

    Returns:
        The module corresponding to the default prompt adapter.
    """
    return {
        "openai": openai,
        "anthropic": anthropic,
        "google": google,
    }[config.DEFAULT_ADAPTER]


def get_default_segmentation_adapter() -> ModuleType:
    """Returns the default image segmentation adapter module.

    Returns:
        The module corresponding to the default segmentation adapter.
    """
    return {
        "som": som,
        "replicate": replicate,
        "ultralytics": ultralytics,
    }[config.DEFAULT_SEGMENTATION_ADAPTER]


__all__ = ["anthropic", "openai", "replicate", "som", "ultralytics", "google"]
```
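A minimal usage sketch of these factories (assuming config.DEFAULT_ADAPTER is set to "anthropic" and config.DEFAULT_SEGMENTATION_ADAPTER to "ultralytics" in openadapt/config.py; both values are assumptions about the config defaults):

```python
# A minimal sketch, assuming config.DEFAULT_ADAPTER == "anthropic"
# and config.DEFAULT_SEGMENTATION_ADAPTER == "ultralytics".
from openadapt import adapters

prompt_adapter = adapters.get_default_prompt_adapter()
# The prompt adapters added in this commit each expose a prompt() entry point.
reply = prompt_adapter.prompt("Describe the next action.")

segmentation_adapter = adapters.get_default_segmentation_adapter()
```

Since adapter selection is a plain dict lookup, an unrecognized config value raises a KeyError rather than falling back to a default.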

openadapt/adapters/anthropic.py

Lines changed: 146 additions & 0 deletions

```python
"""Adapter for Anthropic API with vision support."""

from pprint import pprint

from loguru import logger
import anthropic

from openadapt import cache, config


MAX_TOKENS = 4096
# from https://docs.anthropic.com/claude/docs/vision
MAX_IMAGES = 20
MODEL_NAME = "claude-3-opus-20240229"


@cache.cache()
def create_payload(
    prompt: str,
    system_prompt: str | None = None,
    base64_images: list[tuple[str, str]] | None = None,
    model: str = MODEL_NAME,
    max_tokens: int | None = None,
) -> dict:
    """Creates the payload for the Anthropic API request with image support."""
    messages = []

    user_message_content = []

    max_tokens = max_tokens or MAX_TOKENS
    if max_tokens > MAX_TOKENS:
        logger.warning(f"{max_tokens=} > {MAX_TOKENS=}")
        max_tokens = MAX_TOKENS

    # Add base64 encoded images to the user message content
    if base64_images:
        for image_data in base64_images:
            # Extract media type and base64 data
            media_type, base64_str = image_data.split(";base64,", 1)
            media_type = media_type.split(":")[-1]  # Remove 'data:' prefix

            user_message_content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": media_type,
                        "data": base64_str,
                    },
                }
            )

    # Add text prompt
    user_message_content.append(
        {
            "type": "text",
            "text": prompt,
        }
    )

    # Construct user message
    messages.append(
        {
            "role": "user",
            "content": user_message_content,
        }
    )

    # Prepare the full payload
    payload = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": messages,
    }

    # Add system_prompt as a top-level parameter if provided
    if system_prompt:
        payload["system"] = system_prompt

    return payload


client = anthropic.Anthropic(api_key=config.ANTHROPIC_API_KEY)


@cache.cache()
def get_completion(payload: dict) -> str:
    """Sends a request to the Anthropic API and returns the response."""
    try:
        response = client.messages.create(**payload)
    except Exception as exc:
        logger.exception(exc)
        import ipdb

        ipdb.set_trace()
    """
    Message(
        id='msg_01L55ai2A9q92687mmjMSch3',
        content=[
            ContentBlock(
                text='{
                    "action": [
                        {
                            "name": "press",
                            "key_name": "cmd",
                            "canonical_key_name": "cmd"
                        },
                        ...
                    ]
                }',
                type='text'
            )
        ],
        model='claude-3-opus-20240229',
        role='assistant',
        stop_reason='end_turn',
        stop_sequence=None,
        type='message',
        usage=Usage(input_tokens=4379, output_tokens=109))
    """
    texts = [content_block.text for content_block in response.content]
    return "\n".join(texts)


def prompt(
    prompt: str,
    system_prompt: str | None = None,
    base64_images: list[str] | None = None,
    max_tokens: int | None = None,
) -> str:
    """Public method to get a response from the Anthropic API with image support."""
    if len(base64_images) > MAX_IMAGES:
        # XXX TODO handle this
        raise Exception(
            f"{len(base64_images)=} > {MAX_IMAGES=}. Use a different adapter."
        )
    payload = create_payload(
        prompt,
        system_prompt,
        base64_images,
        max_tokens=max_tokens,
    )
    # pprint(f"payload=\n{payload}")  # Log payload for debugging
    result = get_completion(payload)
    pprint(f"result=\n{result}")  # Log result for debugging
    return result
```
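create_payload() splits each image string on ";base64," and strips the "data:" prefix, so the adapter expects data-URI-style strings. A minimal sketch of preparing one (the file name is hypothetical):

```python
# A minimal sketch of the "data:<media_type>;base64,<data>" format
# that create_payload() parses; "screenshot.png" is a hypothetical input.
import base64

with open("screenshot.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

base64_image = f"data:image/png;base64,{encoded}"
# reply = prompt("What is on screen?", base64_images=[base64_image])
```

Note that prompt() calls len(base64_images) unconditionally, so passing base64_images=None raises a TypeError rather than being treated as "no images".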

openadapt/adapters/google.py

Lines changed: 74 additions & 0 deletions

```python
"""Adapter for Google Gemini.

See https://ai.google.dev/tutorials/python_quickstart for documentation.
"""

from pprint import pprint

from PIL import Image
import fire
import google.generativeai as genai

from openadapt import cache, config, utils


MAX_TOKENS = 2**20  # 1048576
MODEL_NAME = [
    "gemini-pro-vision",
    "models/gemini-1.5-pro-latest",
][-1]
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts
MAX_IMAGES = {
    "gemini-pro-vision": 16,
    "models/gemini-1.5-pro-latest": 3000,
}[MODEL_NAME]


@cache.cache()
def prompt(
    prompt: str,
    system_prompt: str | None = None,
    base64_images: list[str] | None = None,
    # max_tokens: int | None = None,
    model_name: str = MODEL_NAME,
) -> str:
    """Public method to get a response from the Google API with image support."""
    full_prompt = "\n\n###\n\n".join([s for s in (system_prompt, prompt) if s])
    # HACK
    full_prompt += "\nWhen responding in JSON, you MUST use double quotes around keys."

    # TODO: modify API across all adapters to accept PIL.Image
    images = (
        [utils.utf82image(base64_image) for base64_image in base64_images]
        if base64_images
        else []
    )

    genai.configure(api_key=config.GOOGLE_API_KEY)
    model = genai.GenerativeModel(model_name)
    response = model.generate_content([full_prompt] + images)
    response.resolve()
    pprint(f"response=\n{response}")  # Log response for debugging
    return response.text


def main(text: str, image_path: str | None = None) -> None:
    """Prompt Google Gemini with text and a path to an image."""
    if image_path:
        with Image.open(image_path) as img:
            # Convert image to RGB if it's RGBA (to remove alpha channel)
            if img.mode in ("RGBA", "LA") or (
                img.mode == "P" and "transparency" in img.info
            ):
                img = img.convert("RGB")
            base64_image = utils.image2utf8(img)
    else:
        base64_image = None

    base64_images = [base64_image] if base64_image else None
    output = prompt(text, base64_images=base64_images)
    print(output)


if __name__ == "__main__":
    fire.Fire(main)
```
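Because main() is wrapped with fire.Fire, the module doubles as a command-line tool (e.g. something like `python -m openadapt.adapters.google "What does this screen show?" --image_path screenshot.png`, where the image path is hypothetical). Programmatic use is a minimal sketch:

```python
# A minimal sketch, assuming GOOGLE_API_KEY is set in openadapt/config.py.
from openadapt.adapters import google

# Text-only prompt; base64_images defaults to None, so no image is sent.
answer = google.prompt("Summarize the OpenAdapt project in one sentence.")
print(answer)
```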
