Skip to content

Commit 1254b98

Browse files
Merge pull request #75 from askui/refactor/openrouter-improvements
refactor: OpenRouter model integration and settings
2 parents b3613cb + 11c26a6 commit 1254b98

File tree

14 files changed

+471
-94
lines changed

14 files changed

+471
-94
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,15 +446,15 @@ You can use Vision Agent with [OpenRouter](https://openrouter.ai/) to access a w
446446
```python
447447
from askui import VisionAgent
448448
from askui.models import (
449-
OpenRouterGetModel,
449+
OpenRouterModel,
450450
OpenRouterSettings,
451451
ModelRegistry,
452452
)
453453

454454

455455
# Register OpenRouter model in the registry
456456
custom_models: ModelRegistry = {
457-
"my-custom-model": OpenRouterGetModel(
457+
"my-custom-model": OpenRouterModel(
458458
OpenRouterSettings(
459459
model="anthropic/claude-opus-4",
460460
)

src/askui/models/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
OnMessageCb,
1212
Point,
1313
)
14-
from .openrouter.handler import OpenRouterGetModel
14+
from .openrouter.model import OpenRouterModel
1515
from .openrouter.settings import OpenRouterSettings
1616
from .shared.computer_agent_message_param import (
1717
Base64ImageSourceParam,
@@ -28,6 +28,7 @@
2828
ToolUseBlockParam,
2929
UrlImageSourceParam,
3030
)
31+
from .shared.settings import ChatCompletionsCreateSettings
3132

3233
__all__ = [
3334
"ActModel",
@@ -54,6 +55,7 @@
5455
"ToolResultBlockParam",
5556
"ToolUseBlockParam",
5657
"UrlImageSourceParam",
57-
"OpenRouterGetModel",
58+
"OpenRouterModel",
5859
"OpenRouterSettings",
60+
"ChatCompletionsCreateSettings",
5961
]
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ModelName,
2222
Point,
2323
)
24+
from askui.models.shared.prompts import SYSTEM_PROMPT_GET, build_system_prompt_locate
2425
from askui.models.types.response_schemas import ResponseSchema
2526
from askui.utils.image_utils import (
2627
ImageSource,
@@ -47,8 +48,8 @@ def _inference(
4748
) -> list[anthropic.types.ContentBlock]:
4849
message = self._client.messages.create(
4950
model=model,
50-
max_tokens=self._settings.max_tokens,
51-
temperature=self._settings.temperature,
51+
max_tokens=self._settings.chat_completions_create_settings.max_tokens,
52+
temperature=self._settings.chat_completions_create_settings.temperature,
5253
system=system_prompt,
5354
messages=[
5455
{
@@ -87,12 +88,11 @@ def locate(
8788
prompt = f"Click on {locator_serialized}"
8889
screen_width = self._settings.resolution[0]
8990
screen_height = self._settings.resolution[1]
90-
system_prompt = f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\n" # noqa: E501
9191
scaled_image = scale_image_with_padding(image.root, screen_width, screen_height)
9292
response = self._inference(
9393
image_to_base64(scaled_image),
9494
prompt,
95-
system_prompt,
95+
build_system_prompt_locate(str(screen_width), str(screen_height)),
9696
model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)],
9797
)
9898
assert len(response) > 0
@@ -129,11 +129,10 @@ def get(
129129
max_width=self._settings.resolution[0],
130130
max_height=self._settings.resolution[1],
131131
)
132-
system_prompt = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise." # noqa: E501
133132
response = self._inference(
134133
base64_image=image_to_base64(scaled_image),
135134
prompt=query,
136-
system_prompt=system_prompt,
135+
system_prompt=SYSTEM_PROMPT_GET,
137136
model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)],
138137
)
139138
if len(response) == 0:

src/askui/models/anthropic/settings.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pydantic_settings import BaseSettings
33

44
from askui.models.shared.computer_agent import ComputerAgentSettingsBase
5+
from askui.models.shared.settings import ChatCompletionsCreateSettings
56

67
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
78

@@ -20,8 +21,10 @@ class ClaudeSettingsBase(BaseModel):
2021

2122
class ClaudeSettings(ClaudeSettingsBase):
2223
resolution: tuple[int, int] = Field(default_factory=lambda: (1280, 800))
23-
max_tokens: int = 1000
24-
temperature: float = 0.0
24+
chat_completions_create_settings: ChatCompletionsCreateSettings = Field(
25+
default_factory=ChatCompletionsCreateSettings,
26+
description="Settings for ChatCompletions",
27+
)
2528

2629

2730
class ClaudeComputerAgentSettings(ComputerAgentSettingsBase, ClaudeSettingsBase):

src/askui/models/model_router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
from ..logger import logger
4343
from .anthropic.computer_agent import ClaudeComputerAgent
44-
from .anthropic.handler import ClaudeHandler
44+
from .anthropic.model import ClaudeHandler
4545
from .askui.inference_api import AskUiInferenceApi, AskUiSettings
4646

4747

src/askui/models/openrouter/handler.py

Lines changed: 0 additions & 77 deletions
This file was deleted.
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import json
2+
from typing import TYPE_CHECKING, Any, Optional, Type
3+
4+
import openai
5+
from openai import OpenAI
6+
from typing_extensions import override
7+
8+
from askui.logger import logger
9+
from askui.models.exceptions import QueryNoResponseError
10+
from askui.models.models import GetModel
11+
from askui.models.shared.prompts import SYSTEM_PROMPT_GET
12+
from askui.models.types.response_schemas import ResponseSchema, to_response_schema
13+
from askui.utils.image_utils import ImageSource
14+
15+
from .settings import OpenRouterSettings
16+
17+
if TYPE_CHECKING:
18+
from openai.types.chat.completion_create_params import ResponseFormat
19+
20+
21+
def _clean_schema_refs(schema: dict[str, Any] | list[Any]) -> None:
22+
"""Remove title fields that are at the same level as $ref fields as they are not supported by OpenAI.""" # noqa: E501
23+
if isinstance(schema, dict):
24+
if "$ref" in schema and "title" in schema:
25+
del schema["title"]
26+
for value in schema.values():
27+
if isinstance(value, (dict, list)):
28+
_clean_schema_refs(value)
29+
elif isinstance(schema, list):
30+
for item in schema:
31+
if isinstance(item, (dict, list)):
32+
_clean_schema_refs(item)
33+
34+
35+
class OpenRouterModel(GetModel):
    """
    This class implements the GetModel interface for the OpenRouter API.

    Args:
        settings (OpenRouterSettings): The settings for the OpenRouter model.

    Example:
        ```python
        from askui import VisionAgent
        from askui.models import (
            OpenRouterModel,
            OpenRouterSettings,
            ModelRegistry,
        )

        # Register OpenRouter model in the registry
        custom_models: ModelRegistry = {
            "my-custom-model": OpenRouterModel(
                OpenRouterSettings(
                    model="anthropic/claude-opus-4",
                )
            ),
        }

        with VisionAgent(models=custom_models, model={"get":"my-custom-model"}) as agent:
            result = agent.get("What is the main heading on the screen?")
            print(result)
        ```
    """  # noqa: E501

    def __init__(
        self,
        settings: OpenRouterSettings | None = None,
        client: Optional[OpenAI] = None,
    ):
        """Create the model, building a default OpenAI-compatible client if none is given.

        Args:
            settings: OpenRouter configuration; defaults to ``OpenRouterSettings()``.
            client: Pre-configured OpenAI client (useful for testing); when omitted,
                one is created from the API key and base URL in the settings.
        """
        self._settings = settings or OpenRouterSettings()

        # OpenRouter exposes an OpenAI-compatible API, so the stock OpenAI
        # client is pointed at the OpenRouter base URL.
        self._client = (
            client
            if client is not None
            else OpenAI(
                api_key=self._settings.open_router_api_key.get_secret_value(),
                base_url=str(self._settings.base_url),
            )
        )

    def _predict(
        self,
        image_url: str,
        instruction: str,
        prompt: str,
        response_schema: type[ResponseSchema] | None,
    ) -> str | None | ResponseSchema:
        """Send one image + text chat completion and optionally validate the reply.

        Args:
            image_url: Data URL (or remote URL) of the screenshot to analyze.
            instruction: The user's query, appended to ``prompt``.
            prompt: System-style preamble prepended to the instruction.
            response_schema: Optional schema; when set, the model is forced to
                return strict JSON matching it.

        Returns:
            The validated schema value when a schema was given, the raw string
            reply otherwise, or ``None`` when the model returned no content.

        Raises:
            ValueError: If a schema was requested but the reply is not valid JSON.
        """
        extra_body: dict[str, object] = {}

        # Optional fallback model list (OpenRouter routing feature).
        if len(self._settings.models) > 0:
            extra_body["models"] = self._settings.models

        _response_schema = (
            to_response_schema(response_schema) if response_schema else None
        )

        response_format: openai.NotGiven | ResponseFormat = openai.NOT_GIVEN
        if _response_schema is not None:
            # Only route to providers that support response_format parameters.
            extra_body["provider"] = {"require_parameters": True}
            schema = _response_schema.model_json_schema()
            _clean_schema_refs(schema)

            # Wrap the user schema under a single "response" property; $defs
            # must live at the top level of the submitted schema.
            defs = schema.pop("$defs", None)
            schema_response_wrapper = {
                "type": "object",
                "properties": {"response": schema},
                "additionalProperties": False,
                "required": ["response"],
            }
            if defs:
                schema_response_wrapper["$defs"] = defs
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": "user_json_schema",
                    "schema": schema_response_wrapper,
                    "strict": True,
                },
            }

        chat_completion = self._client.chat.completions.create(
            model=self._settings.model,
            extra_body=extra_body,
            response_format=response_format,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": image_url,
                            },
                        },
                        {"type": "text", "text": prompt + instruction},
                    ],
                }
            ],
            stream=False,
            top_p=self._settings.chat_completions_create_settings.top_p,
            temperature=self._settings.chat_completions_create_settings.temperature,
            max_tokens=self._settings.chat_completions_create_settings.max_tokens,
            seed=self._settings.chat_completions_create_settings.seed,
            stop=self._settings.chat_completions_create_settings.stop,
            frequency_penalty=self._settings.chat_completions_create_settings.frequency_penalty,
            presence_penalty=self._settings.chat_completions_create_settings.presence_penalty,
        )

        model_response = chat_completion.choices[0].message.content

        if _response_schema is not None and model_response is not None:
            try:
                response_json = json.loads(model_response)
            except json.JSONDecodeError:
                error_msg = f"Expected JSON, but model {self._settings.model} returned: {model_response}"  # noqa: E501
                logger.error(error_msg)
                raise ValueError(error_msg) from None

            # NOTE(review): assumes strict-mode providers always include the
            # required "response" key; a missing key raises KeyError here.
            validated_response = _response_schema.model_validate(
                response_json["response"]
            )
            return validated_response.root

        return model_response

    @override
    def get(
        self,
        query: str,
        image: ImageSource,
        response_schema: Type[ResponseSchema] | None,
        model_choice: str,
    ) -> ResponseSchema | str:
        """Answer a question about a screenshot via OpenRouter.

        Args:
            query: The question to answer about the image.
            image: Screenshot to analyze; sent as a data URL.
            response_schema: Optional schema to validate/structure the answer.
            model_choice: Name of the model choice (used in error reporting).

        Returns:
            The model's answer, validated against ``response_schema`` when given.

        Raises:
            QueryNoResponseError: If the model produced no response content.
        """
        response = self._predict(
            image_url=image.to_data_url(),
            instruction=query,
            prompt=SYSTEM_PROMPT_GET,
            response_schema=response_schema,
        )
        if response is None:
            error_msg = f'No response from model "{model_choice}" to query: "{query}"'
            raise QueryNoResponseError(error_msg, query)
        return response

src/askui/models/openrouter/prompts.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)