|
1 | 1 | import httpx |
| 2 | +from anthropic.types.beta import ( |
| 3 | + BetaTextBlockParam, |
| 4 | + BetaToolChoiceParam, |
| 5 | + BetaToolUnionParam, |
| 6 | +) |
| 7 | +from pydantic import BaseModel, ConfigDict |
2 | 8 | from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential |
3 | 9 | from typing_extensions import override |
4 | 10 |
|
5 | 11 | from askui.models.askui.settings import AskUiComputerAgentSettings |
6 | | -from askui.models.shared.computer_agent import ComputerAgent |
| 12 | +from askui.models.shared.computer_agent import ComputerAgent, ThinkingConfigParam |
7 | 13 | from askui.models.shared.computer_agent_message_param import MessageParam |
8 | 14 | from askui.models.shared.tools import ToolCollection |
9 | 15 | from askui.reporting import Reporter |
10 | 16 |
|
11 | 17 | from ...logger import logger |
12 | 18 |
|
13 | 19 |
|
class RequestBody(BaseModel):
    """Typed payload for the AskUI ``/act/inference`` endpoint.

    Mirrors the Anthropic messages-API request shape so the body can be
    validated and serialized via pydantic instead of a hand-built dict.
    """

    # The anthropic ``Beta*Param`` types are TypedDicts/unions, not pydantic
    # models, so pydantic must be told to accept them as-is.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    max_tokens: int
    messages: list[MessageParam]
    model: str
    tools: list[BetaToolUnionParam]
    betas: list[str]
    system: list[BetaTextBlockParam]
    thinking: ThinkingConfigParam
    tool_choice: BetaToolChoiceParam
| 30 | + |
| 31 | + |
14 | 32 | def is_retryable_error(exception: BaseException) -> bool: |
15 | 33 | """Check if the exception is a retryable error (status codes 429 or 529).""" |
16 | 34 | if isinstance(exception, httpx.HTTPStatusError): |
@@ -47,21 +65,31 @@ def _create_message( |
47 | 65 | model_choice: str, # noqa: ARG002 |
48 | 66 | ) -> MessageParam: |
49 | 67 | try: |
50 | | - request_body = { |
51 | | - "max_tokens": self._settings.max_tokens, |
52 | | - "messages": [msg.model_dump(mode="json") for msg in messages], |
53 | | - "model": self._settings.model, |
54 | | - "tools": self._tool_collection.to_params(), |
55 | | - "betas": self._settings.betas, |
56 | | - "system": [self._system], |
57 | | - } |
| 68 | + request_body = RequestBody( |
| 69 | + max_tokens=self._settings.max_tokens, |
| 70 | + messages=messages, |
| 71 | + model=self._settings.model, |
| 72 | + tools=self._tool_collection.to_params(), |
| 73 | + betas=self._settings.betas, |
| 74 | + system=[self._system], |
| 75 | + tool_choice=self._settings.tool_choice, |
| 76 | + thinking=self._settings.thinking, |
| 77 | + ) |
58 | 78 | response = self._client.post( |
59 | | - "/act/inference", json=request_body, timeout=300.0 |
| 79 | + "/act/inference", |
| 80 | + json=request_body.model_dump( |
| 81 | + mode="json", exclude={"messages": {"stop_reason"}} |
| 82 | + ), |
| 83 | + timeout=300.0, |
60 | 84 | ) |
61 | 85 | response.raise_for_status() |
62 | | - response_data = response.json() |
63 | | - return MessageParam.model_validate(response_data) |
| 86 | + return MessageParam.model_validate_json(response.text) |
64 | 87 | except Exception as e: # noqa: BLE001 |
65 | 88 | if is_retryable_error(e): |
66 | 89 | logger.debug(e) |
| 90 | + if ( |
| 91 | + isinstance(e, httpx.HTTPStatusError) |
| 92 | + and 400 <= e.response.status_code < 500 |
| 93 | + ): |
| 94 | + raise ValueError(e.response.json()) from e |
67 | 95 | raise |
0 commit comments