diff --git a/server/app.py b/server/app.py index a5f2f1d..d1a3a8e 100644 --- a/server/app.py +++ b/server/app.py @@ -5,7 +5,6 @@ import json import logging import os -from typing import Any from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError @@ -24,6 +23,7 @@ ResponseEnvelope, ResponseSession, AssistantMessage, + JsonObject, build_response_from_plan, parse_request_ids, ) @@ -34,7 +34,7 @@ class JsonFormatter(logging.Formatter): def format(self, record: logging.LogRecord) -> str: - payload: dict[str, Any] = { + payload: JsonObject = { "level": record.levelname, "logger": record.name, "message": record.getMessage(), @@ -122,7 +122,7 @@ def build_error_payload( refinement: RefinementStatus | None, code: str, message: str, -) -> dict[str, Any]: +) -> JsonObject: payload = ResponseEnvelope( requestId=request_id, session=ResponseSession.model_validate(session), @@ -197,7 +197,7 @@ def _log_fulfilled_request( ) -def _encode_sse(event: str, payload: Mapping[str, Any]) -> str: +def _encode_sse(event: str, payload: Mapping[str, object]) -> str: return f"event: {event}\ndata: {json.dumps(payload, separators=(',', ':'))}\n\n" diff --git a/server/bridge_types.py b/server/bridge_types.py index 942f67b..9621360 100644 --- a/server/bridge_types.py +++ b/server/bridge_types.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Protocol, TypedDict +from typing import Protocol, TypedDict -from shared.protocol import AgentPlan, RequestEnvelope +from shared.protocol import AgentPlan, JsonObject, RequestEnvelope class RequestProgressPayload(TypedDict): @@ -11,7 +11,7 @@ class RequestProgressPayload(TypedDict): toolCallsUsed: int maxToolCalls: int appliedOperationCount: int - operations: list[dict[str, Any]] + operations: list[JsonObject] message: str lastToolName: str | None progressVersion: int diff --git a/server/codex_bridge/apply_batch.py b/server/codex_bridge/apply_batch.py index 6a82df3..9051fd9 100644 --- a/server/codex_bridge/apply_batch.py +++ b/server/codex_bridge/apply_batch.py @@ -1,10 +1,10 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Any - +from typing import cast from shared.canonical_plan import CanonicalEditAction +from shared.protocol import JsonObject from .canonical_binder import bind_canonical_actions from .models import TurnContext @@ -12,17 +12,15 @@ @dataclass(frozen=True, slots=True) class PreparedApplyBatch: - normalized_batch: list[dict[str, Any]] + normalized_batch: list[JsonObject] render_warnings: list[str] def prepare_apply_batch( context: TurnContext, - arguments: dict[str, Any], + arguments: JsonObject, *, - normalize_operation: Callable[ - [dict[str, Any], int], tuple[dict[str, Any], str | None] - ], + normalize_operation: Callable[[JsonObject, int], tuple[JsonObject, str | None]], ) -> tuple[PreparedApplyBatch | None, str | None]: raw_operations = arguments.get("operations") raw_canonical_actions = arguments.get("canonicalActions") @@ -48,7 +46,7 @@ def prepare_apply_batch( def _prepare_canonical_batch( - context: TurnContext, raw_canonical_actions: list[Any] + context: TurnContext, raw_canonical_actions: Sequence[object] ) -> tuple[PreparedApplyBatch | None, str | None]: canonical_actions: list[CanonicalEditAction] = [] for raw_action in raw_canonical_actions: @@ -76,17 +74,15 @@ def _prepare_canonical_batch( def _prepare_raw_batch( context: TurnContext, - raw_operations: list[Any], - normalize_operation: Callable[ - [dict[str, Any], int], tuple[dict[str, Any], str | None] - ], + raw_operations: Sequence[object], + normalize_operation: Callable[[JsonObject, int], tuple[JsonObject, str | None]], ) -> tuple[PreparedApplyBatch | None, str | None]: - normalized_batch: list[dict[str, Any]] = [] + normalized_batch: list[JsonObject] = [] for index, raw_operation in enumerate(raw_operations): if not isinstance(raw_operation, dict): return None, "Every apply_operations entry must be an object." normalized_operation, error = normalize_operation( - raw_operation, + cast(JsonObject, raw_operation), context.next_operation_sequence + index, ) if error: diff --git a/server/codex_bridge/image_signals.py b/server/codex_bridge/image_signals.py index ff4864c..dee7e13 100644 --- a/server/codex_bridge/image_signals.py +++ b/server/codex_bridge/image_signals.py @@ -3,7 +3,7 @@ import base64 import binascii import io -from typing import Any, Literal +from typing import Literal, TypedDict from shared.analysis_signals import ( ActiveModuleSignal, @@ -12,7 +12,25 @@ RegionSignalSummary, TonalSignalSummary, ) -from shared.protocol import RequestEnvelope +from shared.protocol import JsonObject, RequestEnvelope + + +class PreviewPixel(TypedDict): + red: float + green: float + blue: float + luma: float + saturation: float + + +class PreviewSamples(TypedDict): + width: int + height: int + samples: list[PreviewPixel] + lumas: list[float] + saturations: list[float] + grayscale: list[float] + try: from PIL import Image @@ -41,7 +59,7 @@ def _decode_preview_bytes(request: RequestEnvelope) -> bytes | None: return None -def _preview_samples(image_bytes: bytes) -> dict[str, Any] | None: +def _preview_samples(image_bytes: bytes) -> PreviewSamples | None: if Image is None or not image_bytes: return None @@ -60,7 +78,7 @@ def _preview_samples(image_bytes: bytes) -> dict[str, Any] | None: grayscale: list[float] = [] lumas: list[float] = [] saturations: list[float] = [] - samples: list[dict[str, float]] = [] + samples: list[PreviewPixel] = [] for red, green, blue in pixels: luma = (0.2126 * red + 0.7152 * green + 0.0722 * blue) / 255.0 saturation = (max(red, green, blue) - min(red, green, blue)) / 255.0 @@ -87,7 +105,7 @@ def _preview_samples(image_bytes: bytes) -> dict[str, Any] | None: } -def _tonal_from_preview(preview_samples: dict[str, Any]) -> TonalSignalSummary: +def _tonal_from_preview(preview_samples: PreviewSamples) -> TonalSignalSummary: lumas = sorted(float(value) for value in preview_samples["lumas"]) samples = preview_samples["samples"] total = max(1, len(samples)) @@ -146,7 +164,7 @@ def _tonal_from_histogram(request: RequestEnvelope) -> TonalSignalSummary | None def _sharpness_estimate( - preview_samples: dict[str, Any] | None, + preview_samples: PreviewSamples | None, ) -> Literal["unknown", "soft", "normal", "crisp"]: if preview_samples is None: return "unknown" @@ -190,13 +208,13 @@ def _noise_risk( def _region_slice( - preview_samples: dict[str, Any], + preview_samples: PreviewSamples, *, x_start: float, x_end: float, y_start: float, y_end: float, -) -> list[dict[str, float]]: +) -> list[PreviewPixel]: width = int(preview_samples["width"]) height = int(preview_samples["height"]) samples = preview_samples["samples"] @@ -204,7 +222,7 @@ def _region_slice( right = max(left + 1, min(width, int(width * x_end))) top = max(0, min(height - 1, int(height * y_start))) bottom = max(top + 1, min(height, int(height * y_end))) - region: list[dict[str, float]] = [] + region: list[PreviewPixel] = [] for y in range(top, bottom): start = y * width + left end = y * width + right @@ -212,7 +230,7 @@ def _region_slice( return region -def _region_stats(region: list[dict[str, float]]) -> tuple[float, float]: +def _region_stats(region: list[PreviewPixel]) -> tuple[float, float]: if not region: return 0.0, 0.0 mean_luma = sum(sample["luma"] for sample in region) / len(region) @@ -221,7 +239,7 @@ def _region_stats(region: list[dict[str, float]]) -> tuple[float, float]: def _region_summaries( - preview_samples: dict[str, Any] | None, + preview_samples: PreviewSamples | None, ) -> list[RegionSignalSummary]: if preview_samples is None: return [] @@ -303,7 +321,7 @@ def _active_modules(request: RequestEnvelope) -> tuple[int, list[ActiveModuleSig return len(active_history), signals -def build_image_analysis_signals(request: RequestEnvelope) -> dict[str, Any]: +def build_image_analysis_signals(request: RequestEnvelope) -> JsonObject: preview_samples = _preview_samples(_decode_preview_bytes(request) or b"") tonal = _tonal_from_preview(preview_samples) if preview_samples else None if tonal is None: diff --git a/server/codex_bridge/models.py b/server/codex_bridge/models.py index eebe319..66447f8 100644 --- a/server/codex_bridge/models.py +++ b/server/codex_bridge/models.py @@ -2,9 +2,9 @@ import threading from dataclasses import dataclass, field -from typing import Any, TypedDict +from typing import TypedDict -from shared.protocol import AgentPlan, RequestEnvelope +from shared.protocol import AgentPlan, JsonObject, RequestEnvelope @dataclass(frozen=True, slots=True) @@ -60,19 +60,19 @@ class TurnContext: current_preview_bytes: bytes preview_mime_type: str base_image_revision_id: str - state_payload: dict[str, Any] - setting_by_id: dict[str, dict[str, Any]] + state_payload: JsonObject + setting_by_id: dict[str, JsonObject] base_float_setting_numbers: dict[str, float] live_run_enabled: bool max_tool_calls: int tool_calls_used: int = 0 consecutive_read_only_tool_calls: int = 0 - applied_operations: list[dict[str, Any]] = field(default_factory=list) + applied_operations: list[JsonObject] = field(default_factory=list) next_operation_sequence: int = 1 render_event: threading.Event = field(default_factory=threading.Event) rendered_preview_bytes: bytes | None = None requires_render_callback: bool = False - last_applied_batch: list[dict[str, Any]] = field(default_factory=list) + last_applied_batch: list[JsonObject] = field(default_factory=list) last_applied_summary: str | None = None last_verifier_status: str | None = None last_verifier_summary: str | None = None @@ -85,7 +85,7 @@ class TurnRunState(TypedDict): final_message: str | None turn_error: str | None completed: bool - token_usage_last: dict[str, Any] | None - token_usage_total: dict[str, Any] | None + token_usage_last: JsonObject | None + token_usage_total: JsonObject | None last_activity_at: float last_activity_method: str | None diff --git a/server/codex_bridge/operations.py b/server/codex_bridge/operations.py index ac6685c..4168b24 100644 --- a/server/codex_bridge/operations.py +++ b/server/codex_bridge/operations.py @@ -4,9 +4,10 @@ import copy import json -from typing import Any +from collections.abc import Sequence +from typing import cast -from shared.protocol import AgentPlan +from shared.protocol import AgentPlan, JsonObject from .apply_batch import prepare_apply_batch from .config import _TOOL_APPLY_OPERATIONS, _WHITE_BALANCE_ACTION_PATH_PREFIXES, logger @@ -17,11 +18,11 @@ class OperationsMixin: def _apply_operations_tool_call( self, context: TurnContext, - arguments: dict[str, Any], + arguments: JsonObject, *, thread_id: str | None = None, turn_id: str | None = None, - ) -> dict[str, Any]: + ) -> JsonObject: if not context.live_run_enabled: return self._tool_error_response( "apply_operations is only available when live run mode is enabled." @@ -70,10 +71,10 @@ def _apply_operations_tool_call( error=apply_error, ) return self._tool_error_response(apply_error) - applied_batch: list[dict[str, Any]] = [] + applied_batch: list[JsonObject] = [] step_summaries: list[str] = [] latest_preview_url: str | None = None - latest_verifier_result: dict[str, Any] | None = None + latest_verifier_result: JsonObject | None = None for step_index, operation in enumerate(ordered_batch, start=1): apply_error = self._apply_live_operation_step(context, operation) @@ -118,7 +119,7 @@ def _apply_operations_tool_call( success=True, ) - content_items: list[dict[str, Any]] = [ + content_items: list[JsonObject] = [ { "type": "inputText", "text": ( @@ -150,7 +151,7 @@ def _apply_operations_tool_call( def _apply_live_operation_step( self, context: TurnContext, - operation: dict[str, Any], + operation: JsonObject, ) -> str | None: apply_error, _ = self._apply_operation_to_settings( context.setting_by_id, operation @@ -174,7 +175,7 @@ def _apply_live_operation_step( def _wait_for_live_render( self, context: TurnContext - ) -> tuple[str | None, dict[str, Any] | None, str | None]: + ) -> tuple[str | None, JsonObject | None, str | None]: logger.info( "waiting_for_mid_turn_render", extra={ @@ -215,13 +216,15 @@ def _wait_for_live_render( def _summarize_live_operation( self, context: TurnContext, - operation: dict[str, Any], + operation: JsonObject, ) -> str: target = operation.get("target") - target_dict = target if isinstance(target, dict) else {} + target_dict: JsonObject = ( + cast(JsonObject, target) if isinstance(target, dict) else {} + ) action_path = str(target_dict.get("actionPath") or "unknown") setting_id = str(target_dict.get("settingId") or "") - setting = context.setting_by_id.get(setting_id, {}) + setting = context.setting_by_id.get(setting_id) or {} module_label = str(setting.get("moduleLabel") or "") control_label = str(setting.get("label") or action_path.rsplit("/", 1)[-1]) label = " / ".join(part for part in (module_label, control_label) if part) @@ -231,24 +234,25 @@ def _summarize_live_operation( value = operation.get("value") if not isinstance(value, dict): return label + value_dict = cast(JsonObject, value) kind = operation.get("kind") if kind == "set-float": - number = value.get("number") - mode = value.get("mode") + number = value_dict.get("number") + mode = value_dict.get("mode") if isinstance(number, (int, float)): if mode == "delta": return f"{label} {float(number):+0.3f}" return f"{label} = {float(number):0.3f}" if kind == "set-choice": - choice_id = value.get("choiceId") - choice_value = value.get("choiceValue") + choice_id = value_dict.get("choiceId") + choice_value = value_dict.get("choiceValue") if isinstance(choice_id, str) and choice_id: return f"{label} -> {choice_id}" if isinstance(choice_value, int): return f"{label} -> choice {choice_value}" if kind == "set-bool": - bool_value = value.get("boolValue") + bool_value = value_dict.get("boolValue") if isinstance(bool_value, bool): return f"{label} -> {'on' if bool_value else 'off'}" return label @@ -263,10 +267,10 @@ def _refresh_preview_after_operations(self, context: TurnContext) -> None: def _normalize_tool_operation( self, context: TurnContext, - raw_operation: dict[str, Any], + raw_operation: JsonObject, *, sequence_number: int, - ) -> tuple[dict[str, Any], str | None]: + ) -> tuple[JsonObject, str | None]: for key in ("kind", "target", "value"): if key not in raw_operation: return {}, f"operation is missing required member '{key}'" @@ -321,7 +325,7 @@ def _normalize_tool_operation( @staticmethod def _setting_ids_for_action_path( - setting_by_id: dict[str, dict[str, Any]], + setting_by_id: dict[str, JsonObject], action_path: str, ) -> list[str]: return [ @@ -331,7 +335,7 @@ def _setting_ids_for_action_path( ] @staticmethod - def _choice_mapping(setting: dict[str, Any]) -> dict[int, str]: + def _choice_mapping(setting: JsonObject) -> dict[int, str]: choices = setting.get("choices") mapping: dict[int, str] = {} if not isinstance(choices, list): @@ -339,8 +343,9 @@ def _choice_mapping(setting: dict[str, Any]) -> dict[int, str]: for choice in choices: if not isinstance(choice, dict): continue - value = choice.get("choiceValue") - choice_id = choice.get("choiceId") + choice_dict = cast(JsonObject, choice) + value = choice_dict.get("choiceValue") + choice_id = choice_dict.get("choiceId") if isinstance(value, int) and isinstance(choice_id, str) and choice_id: mapping[value] = choice_id return mapping @@ -353,11 +358,13 @@ def _is_white_balance_action_path(action_path: str) -> bool: ) @classmethod - def _white_balance_operation_rank( - cls, operation: dict[str, Any] - ) -> tuple[int, str]: + def _white_balance_operation_rank(cls, operation: JsonObject) -> tuple[int, str]: target = operation.get("target") - action_path = target.get("actionPath") if isinstance(target, dict) else None + action_path = ( + cast(JsonObject, target).get("actionPath") + if isinstance(target, dict) + else None + ) if not isinstance(action_path, str): return (99, "") leaf = action_path.rsplit("/", 1)[-1].lower() @@ -383,8 +390,8 @@ def _white_balance_operation_rank( return (channel_order.get(leaf, 99), leaf) def _order_operations_for_apply( - self, operations: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + self, operations: list[JsonObject] + ) -> list[JsonObject]: ordered = list(operations) wb_indexes = [ index @@ -404,21 +411,22 @@ def _order_operations_for_apply( def _log_white_balance_tool_call( self, context: TurnContext, - attempted_operations: list[Any], - applied_operations: list[Any], + attempted_operations: Sequence[object], + applied_operations: Sequence[object], *, success: bool, error: str | None = None, ) -> None: - def _extract_paths(operations: list[Any]) -> list[str]: + def _extract_paths(operations: Sequence[object]) -> list[str]: paths: list[str] = [] for operation in operations: if not isinstance(operation, dict): continue - target = operation.get("target") + operation_dict = cast(JsonObject, operation) + target = operation_dict.get("target") if not isinstance(target, dict): continue - action_path = target.get("actionPath") + action_path = cast(JsonObject, target).get("actionPath") if isinstance(action_path, str) and self._is_white_balance_action_path( action_path ): @@ -447,15 +455,16 @@ def _extract_paths(operations: list[Any]) -> list[str]: def _apply_operation_to_settings( self, - setting_by_id: dict[str, dict[str, Any]], - operation: dict[str, Any], - ) -> tuple[str | None, dict[str, Any] | None]: + setting_by_id: dict[str, JsonObject], + operation: JsonObject, + ) -> tuple[str | None, JsonObject | None]: target = operation.get("target") if not isinstance(target, dict): return "operation target must be an object", None + target_dict = cast(JsonObject, target) - setting_id = target.get("settingId") - action_path = target.get("actionPath") + setting_id = target_dict.get("settingId") + action_path = target_dict.get("actionPath") if not isinstance(setting_id, str) or not isinstance(action_path, str): return "operation target requires settingId and actionPath", None @@ -475,8 +484,9 @@ def _apply_operation_to_settings( value = operation.get("value") if not isinstance(value, dict): return "operation value must be an object", None + value_dict = cast(JsonObject, value) - mode = value.get("mode") + mode = value_dict.get("mode") supported_modes = setting.get("supportedModes") if not isinstance(mode, str): return "operation value requires mode", None @@ -484,7 +494,7 @@ def _apply_operation_to_settings( return f"mode '{mode}' is not supported by settingId '{setting_id}'", None if kind == "set-float": - number_value = value.get("number") + number_value = value_dict.get("number") if not isinstance(number_value, (int, float)): return ( f"set-float operation requires numeric value.number for '{setting_id}'", @@ -521,7 +531,7 @@ def _apply_operation_to_settings( } if kind == "set-choice": - choice_value = value.get("choiceValue") + choice_value = value_dict.get("choiceValue") if not isinstance(choice_value, int): return ( f"set-choice operation requires integer value.choiceValue for '{setting_id}'", @@ -533,7 +543,7 @@ def _apply_operation_to_settings( f"choiceValue {choice_value} is not valid for '{setting_id}'", None, ) - choice_id = value.get("choiceId") + choice_id = value_dict.get("choiceId") if isinstance(choice_id, str) and choice_mapping.get(choice_value) not in { None, choice_id, @@ -557,7 +567,7 @@ def _apply_operation_to_settings( } if kind == "set-bool": - bool_value = value.get("boolValue") + bool_value = value_dict.get("boolValue") if not isinstance(bool_value, bool): return ( f"set-bool operation requires boolean value.boolValue for '{setting_id}'", diff --git a/server/codex_bridge/prompting.py b/server/codex_bridge/prompting.py index 331c079..e45ad95 100644 --- a/server/codex_bridge/prompting.py +++ b/server/codex_bridge/prompting.py @@ -5,9 +5,9 @@ import base64 import binascii import json -from typing import Any +from typing import cast -from shared.protocol import AgentPlan, RequestEnvelope +from shared.protocol import AgentPlan, JsonObject, RequestEnvelope from .config import _DEFAULT_HISTOGRAM_BINS, _DEFAULT_MAX_TOOL_CALLS_WITHOUT_APPLY from .errors import CodexAppServerError @@ -72,10 +72,12 @@ def _register_turn_context( preview_data_url: str, ) -> None: preview_mime_type, preview_bytes = self._decode_preview_image(request) - state_payload = json.loads(json.dumps(self._build_prompt_payload(request))) + state_payload = cast( + JsonObject, json.loads(json.dumps(self._build_prompt_payload(request))) + ) image_snapshot = state_payload.get("imageSnapshot", {}) editable_settings = image_snapshot.get("editableSettings", []) - setting_by_id: dict[str, dict[str, Any]] = {} + setting_by_id: dict[str, JsonObject] = {} base_float_setting_numbers: dict[str, float] = {} if isinstance(editable_settings, list): for setting in editable_settings: @@ -145,7 +147,7 @@ def _finalize_plan_with_live_context( } ) - normalized_operations: list[dict[str, Any]] = [] + normalized_operations: list[JsonObject] = [] seen_operation_ids: set[str] = set() for index, operation in enumerate(merged_operations, start=1): operation_copy = dict(operation) @@ -187,7 +189,7 @@ def _rebin(source_bins: list[int], target_count: int) -> list[int]: return rebinned @classmethod - def _trim_histogram_payload(cls, request: RequestEnvelope) -> dict[str, Any] | None: + def _trim_histogram_payload(cls, request: RequestEnvelope) -> JsonObject | None: histogram = request.imageSnapshot.histogram if histogram is None: return None @@ -211,10 +213,10 @@ def _trim_histogram_payload(cls, request: RequestEnvelope) -> dict[str, Any] | N "channels": trimmed_channels, } - def _build_prompt_payload(self, request: RequestEnvelope) -> dict[str, Any]: - compact_settings: list[dict[str, Any]] = [] + def _build_prompt_payload(self, request: RequestEnvelope) -> JsonObject: + compact_settings: list[JsonObject] = [] for setting in request.imageSnapshot.editableSettings: - compact_setting: dict[str, Any] = { + compact_setting: JsonObject = { "moduleId": setting.moduleId, "moduleLabel": setting.moduleLabel, "settingId": setting.settingId, @@ -243,7 +245,7 @@ def _build_prompt_payload(self, request: RequestEnvelope) -> dict[str, Any]: compact_settings.append(compact_setting) metadata = request.imageSnapshot.metadata - metadata_payload: dict[str, Any] = { + metadata_payload: JsonObject = { "width": metadata.width, "height": metadata.height, } @@ -288,8 +290,8 @@ def _build_turn_input( request: RequestEnvelope, *, preview_data_url: str | None = None, - ) -> list[dict[str, Any]]: - items: list[dict[str, Any]] = [] + ) -> list[JsonObject]: + items: list[JsonObject] = [] conv_id = request.session.conversationId history = getattr(self, "_conversation_histories", {}).get(conv_id) diff --git a/server/codex_bridge/request_state.py b/server/codex_bridge/request_state.py index e28b967..272f87b 100644 --- a/server/codex_bridge/request_state.py +++ b/server/codex_bridge/request_state.py @@ -2,9 +2,11 @@ # pyright: reportAttributeAccessIssue=false -from typing import Any +from typing import cast -from shared.protocol import AgentPlan +from pydantic import BaseModel + +from shared.protocol import AgentPlan, JsonObject from server.bridge_types import RequestProgressPayload from .config import logger @@ -12,27 +14,29 @@ from .models import ActiveRequestState, CancelRequestKey, TurnContext -def build_output_schema(agent_plan_type: Any) -> dict[str, Any]: +def build_output_schema(agent_plan_type: type[BaseModel]) -> JsonObject: schema = agent_plan_type.model_json_schema() - def _rewrite(node: Any) -> None: + def _rewrite(node: object) -> None: if isinstance(node, dict): - properties = node.get("properties") + node_dict = cast(JsonObject, node) + properties = node_dict.get("properties") if isinstance(properties, dict): - node["required"] = list(properties.keys()) - node.setdefault("additionalProperties", False) + node_dict["required"] = list(properties.keys()) + if "additionalProperties" not in node_dict: + node_dict["additionalProperties"] = False for child in properties.values(): _rewrite(child) for key in ("items", "anyOf", "allOf", "oneOf", "prefixItems"): - child = node.get(key) + child = node_dict.get(key) if isinstance(child, list): for item in child: _rewrite(item) elif isinstance(child, dict): _rewrite(child) - defs = node.get("$defs") + defs = node_dict.get("$defs") if isinstance(defs, dict): for child in defs.values(): _rewrite(child) @@ -46,7 +50,7 @@ def _rewrite(node: Any) -> None: class RequestStateMixin: @staticmethod - def _build_output_schema() -> dict[str, Any]: + def _build_output_schema() -> JsonObject: return build_output_schema(AgentPlan) def _register_request(self, request) -> ActiveRequestState: # type: ignore[no-untyped-def] @@ -165,7 +169,9 @@ def get_request_progress( "message": active_request.message, "lastToolName": active_request.last_tool_name, "progressVersion": active_request.progress_version, - "requiresRenderCallback": context.requires_render_callback if context else False, + "requiresRenderCallback": context.requires_render_callback + if context + else False, } def _is_cancelled(self, active_request: ActiveRequestState) -> bool: diff --git a/server/codex_bridge/tool_routing.py b/server/codex_bridge/tool_routing.py index 037ad68..6334f95 100644 --- a/server/codex_bridge/tool_routing.py +++ b/server/codex_bridge/tool_routing.py @@ -3,7 +3,9 @@ # pyright: reportAttributeAccessIssue=false import json -from typing import Any +from typing import cast + +from shared.protocol import JsonObject from .config import ( _DEFAULT_MAX_CONSECUTIVE_READ_ONLY_TOOL_CALLS, @@ -19,7 +21,7 @@ class ToolRoutingMixin: @staticmethod - def _dynamic_tools() -> list[dict[str, Any]]: + def _dynamic_tools() -> list[JsonObject]: empty_object_schema = { "type": "object", "properties": {}, @@ -75,7 +77,7 @@ def _dynamic_tools() -> list[dict[str, Any]]: }, ] - def _handle_server_request_locked(self, message: dict[str, Any]) -> None: + def _handle_server_request_locked(self, message: JsonObject) -> None: method = message.get("method") request_id = message.get("id") if request_id is None: @@ -117,9 +119,7 @@ def _handle_server_request_locked(self, message: dict[str, Any]) -> None: } ) - def _handle_dynamic_tool_call_locked( - self, message: dict[str, Any] - ) -> dict[str, Any]: + def _handle_dynamic_tool_call_locked(self, message: JsonObject) -> JsonObject: params = message.get("params", {}) thread_id = params.get("threadId") turn_id = params.get("turnId") @@ -226,7 +226,7 @@ def _handle_dynamic_tool_call_locked( for content_item in content_items: if not isinstance(content_item, dict): continue - text = content_item.get("text") + text = cast(JsonObject, content_item).get("text") if isinstance(text, str) and text: tool_error = text break @@ -372,7 +372,7 @@ def _register_tool_call_progress_locked( return None @staticmethod - def _tool_error_response(message: str) -> dict[str, Any]: + def _tool_error_response(message: str) -> JsonObject: return { "success": False, "contentItems": [{"type": "inputText", "text": message}], diff --git a/server/codex_bridge/transport.py b/server/codex_bridge/transport.py index e3ce7fc..c46cf9f 100644 --- a/server/codex_bridge/transport.py +++ b/server/codex_bridge/transport.py @@ -6,7 +6,9 @@ import select import subprocess import time -from typing import Any, cast +from typing import cast + +from shared.protocol import JsonObject from .config import _CLIENT_INFO, logger from .errors import CodexAppServerError @@ -87,10 +89,10 @@ def _reset_process_locked(self) -> None: def _send_request_locked( self, method: str, - params: Any, + params: object, deadline: float, active_request: ActiveRequestState | None, - ) -> dict[str, Any]: + ) -> JsonObject: request_id = self._next_request_id self._next_request_id += 1 self._send_json_locked( @@ -106,7 +108,9 @@ def _send_request_locked( if "error" in message: error = message["error"] error_message = ( - error.get("message") if isinstance(error, dict) else None + cast(JsonObject, error).get("message") + if isinstance(error, dict) + else None ) raise CodexAppServerError( "codex_jsonrpc_error", @@ -117,13 +121,15 @@ def _send_request_locked( return message self._handle_message_locked(message, None) - def _send_notification_locked(self, method: str, params: Any | None = None) -> None: - payload: dict[str, Any] = {"jsonrpc": "2.0", "method": method} + def _send_notification_locked( + self, method: str, params: object | None = None + ) -> None: + payload: JsonObject = {"jsonrpc": "2.0", "method": method} if params is not None: payload["params"] = params self._send_json_locked(payload) - def _send_json_locked(self, payload: dict[str, Any]) -> None: + def _send_json_locked(self, payload: JsonObject) -> None: if not self._process or not self._process.stdin: raise CodexAppServerError( "codex_process_unavailable", "Codex app server is not running" @@ -143,7 +149,7 @@ def _read_message_locked( active_request: ActiveRequestState | None = None, *, max_wait_seconds: float | None = None, - ) -> dict[str, Any] | None: + ) -> JsonObject | None: if not self._process or not self._process.stdout or not self._process.stderr: raise CodexAppServerError( "codex_process_unavailable", "Codex app server is not running" @@ -199,4 +205,4 @@ def _read_message_locked( "codex_invalid_json", f"Codex emitted non-object JSON: {line.rstrip()}", ) - return cast(dict[str, Any], payload) + return cast(JsonObject, payload) diff --git a/server/codex_bridge/turns.py b/server/codex_bridge/turns.py index fbe0753..d4ac362 100644 --- a/server/codex_bridge/turns.py +++ b/server/codex_bridge/turns.py @@ -3,9 +3,9 @@ # pyright: reportAttributeAccessIssue=false import time -from typing import Any +from typing import cast -from shared.protocol import AgentPlan, RequestEnvelope +from shared.protocol import AgentPlan, JsonObject, RequestEnvelope from .config import ( _DEFAULT_APPROVAL_POLICY, @@ -58,7 +58,7 @@ def _get_or_create_thread_locked( ) return existing - params: dict[str, Any] = { + params: JsonObject = { "cwd": self._cwd, "approvalPolicy": _DEFAULT_APPROVAL_POLICY, "sandbox": _DEFAULT_SANDBOX, @@ -277,7 +277,7 @@ def _run_turn_locked( active_request.codex_turn_id = None def _handle_message_locked( - self, message: dict[str, Any], turn_state: TurnRunState | None + self, message: JsonObject, turn_state: TurnRunState | None ) -> None: if "method" in message and "id" in message: self._handle_server_request_locked(message) @@ -287,7 +287,9 @@ def _handle_message_locked( method = message["method"] raw_params = message.get("params", {}) - params = raw_params if isinstance(raw_params, dict) else {} + params: JsonObject = ( + cast(JsonObject, raw_params) if isinstance(raw_params, dict) else {} + ) if method == "error": if ( @@ -296,7 +298,9 @@ def _handle_message_locked( and params.get("turnId") == turn_state["turn_id"] ): raw_error = params.get("error", {}) - error = raw_error if isinstance(raw_error, dict) else {} + error = ( + cast(JsonObject, raw_error) if isinstance(raw_error, dict) else {} + ) turn_state["turn_error"] = self._extract_error_message( error.get("message") or "Codex app server reported an error" ) @@ -321,12 +325,13 @@ def _handle_message_locked( return usage = params.get("tokenUsage", {}) if isinstance(usage, dict): - last_usage = usage.get("last") - total_usage = usage.get("total") + usage_dict = cast(JsonObject, usage) + last_usage = usage_dict.get("last") + total_usage = usage_dict.get("total") if isinstance(last_usage, dict): - turn_state["token_usage_last"] = last_usage + turn_state["token_usage_last"] = cast(JsonObject, last_usage) if isinstance(total_usage, dict): - turn_state["token_usage_total"] = total_usage + turn_state["token_usage_total"] = cast(JsonObject, total_usage) return if method == "item/completed": @@ -336,7 +341,7 @@ def _handle_message_locked( ): return raw_item = params.get("item", {}) - item = raw_item if isinstance(raw_item, dict) else {} + item = cast(JsonObject, raw_item) if isinstance(raw_item, dict) else {} if item.get("type") == "agentMessage": text = item.get("text") turn_state["final_message"] = text if isinstance(text, str) else None @@ -348,7 +353,7 @@ def _handle_message_locked( if params.get("id") != turn_state["turn_id"]: return raw_msg = params.get("msg", {}) - msg = raw_msg if isinstance(raw_msg, dict) else {} + msg = cast(JsonObject, raw_msg) if isinstance(raw_msg, dict) else {} last_agent_message = msg.get("last_agent_message") if isinstance(last_agent_message, str) and last_agent_message: turn_state["final_message"] = last_agent_message @@ -359,12 +364,13 @@ def _handle_message_locked( if params.get("threadId") != turn_state["thread_id"]: return raw_turn = params.get("turn", {}) - turn = raw_turn if isinstance(raw_turn, dict) else {} + turn = cast(JsonObject, raw_turn) if isinstance(raw_turn, dict) else {} if turn.get("id") != turn_state["turn_id"]: return raw_error = turn.get("error") if isinstance(raw_error, dict): + error_dict = cast(JsonObject, raw_error) turn_state["turn_error"] = self._extract_error_message( - raw_error.get("message") or "Codex turn failed" + error_dict.get("message") or "Codex turn failed" ) turn_state["completed"] = True diff --git a/server/codex_bridge/verifier.py b/server/codex_bridge/verifier.py index 533cb92..de9c0e3 100644 --- a/server/codex_bridge/verifier.py +++ b/server/codex_bridge/verifier.py @@ -2,7 +2,10 @@ import io import json -from typing import Any + +from typing import cast + +from shared.protocol import JsonObject from .config import logger from .models import TurnContext @@ -76,7 +79,7 @@ def _editing_profile(context: TurnContext) -> str: @staticmethod def _summed_deltas( - operations: list[dict[str, Any]], + operations: list[JsonObject], *, action_terms: tuple[str, ...], ) -> float: @@ -86,15 +89,17 @@ def _summed_deltas( value = operation.get("value") if not isinstance(target, dict) or not isinstance(value, dict): continue - action_path = str(target.get("actionPath") or "").lower() + target_dict = cast(JsonObject, target) + value_dict = cast(JsonObject, value) + action_path = str(target_dict.get("actionPath") or "").lower() if not any(term in action_path for term in action_terms): continue - number = value.get("number") + number = value_dict.get("number") if isinstance(number, (int, float)): total += float(number) return total - def _build_live_verifier_feedback(self, context: TurnContext) -> dict[str, Any]: + def _build_live_verifier_feedback(self, context: TurnContext) -> JsonObject: base_metrics = self._preview_metrics(context.base_preview_bytes) current_metrics = self._preview_metrics(context.current_preview_bytes) profile = self._editing_profile(context) @@ -114,7 +119,7 @@ def _build_live_verifier_feedback(self, context: TurnContext) -> dict[str, Any]: context.last_verifier_summary = summary return result - checks: list[dict[str, Any]] = [] + checks: list[JsonObject] = [] exposure_delta = self._summed_deltas( context.last_applied_batch, action_terms=("exposure", "filmic", "toneeq") ) @@ -182,11 +187,12 @@ def _build_live_verifier_feedback(self, context: TurnContext) -> dict[str, Any]: "detected after the latest live edits." ) else: - summary = "Verifier fail: " + " ".join( - check["detail"] + details = [ + detail for check in checks - if isinstance(check.get("detail"), str) - ) + if isinstance((detail := check.get("detail")), str) + ] + summary = "Verifier fail: " + " ".join(details) result = { "status": status, @@ -215,5 +221,5 @@ def _build_live_verifier_feedback(self, context: TurnContext) -> dict[str, Any]: return result @staticmethod - def _verifier_feedback_text(result: dict[str, Any]) -> str: + def _verifier_feedback_text(result: JsonObject) -> str: return "Verifier summary JSON:\n" + json.dumps(result, separators=(",", ":")) diff --git a/server/evals/harness.py b/server/evals/harness.py index 119d1be..6eb938e 100644 --- a/server/evals/harness.py +++ b/server/evals/harness.py @@ -7,9 +7,9 @@ import json from dataclasses import asdict from pathlib import Path -from typing import Any +from typing import cast -from shared.protocol import AgentPlan, EditableSetting +from shared.protocol import AgentPlan, EditableSetting, JsonObject from server.codex_bridge.canonical_binder import bind_canonical_actions from server.codex_bridge.verifier import VerifierMixin @@ -205,7 +205,7 @@ def main(argv: list[str] | None = None) -> int: def _evaluate_expectations( case: EvaluationCase, submission: EvaluationSubmission, - resolved_operations: list[dict[str, Any]], + resolved_operations: list[JsonObject], ) -> list[str]: failures: list[str] = [] expectation = case.expectations @@ -332,7 +332,7 @@ def _check_metric_threshold( def _validate_operations( - settings: list[EditableSetting], operations: list[dict[str, Any]] + settings: list[EditableSetting], operations: list[JsonObject] ) -> dict[str, int]: setting_by_id = {setting.settingId: setting for setting in settings} unknown_targets = 0 @@ -343,8 +343,10 @@ def _validate_operations( if not isinstance(target, dict) or not isinstance(value, dict): validation_failures += 1 continue - setting_id = target.get("settingId") - action_path = target.get("actionPath") + target_dict = cast(JsonObject, target) + value_dict = cast(JsonObject, value) + setting_id = target_dict.get("settingId") + action_path = target_dict.get("actionPath") if not isinstance(setting_id, str) or not isinstance(action_path, str): validation_failures += 1 continue @@ -352,16 +354,14 @@ def _validate_operations( if setting is None or setting.actionPath != action_path: unknown_targets += 1 continue - validation_failures += _operation_validation_failures(setting, value) + validation_failures += _operation_validation_failures(setting, value_dict) return { "unknown_targets": unknown_targets, "validation_failures": validation_failures, } -def _operation_validation_failures( - setting: EditableSetting, value: dict[str, Any] -) -> int: +def _operation_validation_failures(setting: EditableSetting, value: JsonObject) -> int: kind = setting.kind if kind == "set-float": return _float_validation_failures(setting, value) @@ -370,7 +370,7 @@ def _operation_validation_failures( return 0 -def _float_validation_failures(setting: EditableSetting, value: dict[str, Any]) -> int: +def _float_validation_failures(setting: EditableSetting, value: JsonObject) -> int: mode = value.get("mode") number = value.get("number") if mode not in {"set", "delta"} or not isinstance(number, (int, float)): @@ -385,7 +385,7 @@ def _float_validation_failures(setting: EditableSetting, value: dict[str, Any]) return 0 if minimum <= candidate <= maximum else 1 -def _choice_validation_failures(setting: EditableSetting, value: dict[str, Any]) -> int: +def _choice_validation_failures(setting: EditableSetting, value: JsonObject) -> int: if value.get("mode") != "set": return 1 choice_id = value.get("choiceId") diff --git a/shared/protocol.py b/shared/protocol.py index 1af6d26..fc30dfa 100644 --- a/shared/protocol.py +++ b/shared/protocol.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal, cast +from typing import Literal, cast from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -10,6 +10,9 @@ SCHEMA_VERSION = "3.0" DEFAULT_REFINEMENT_MAX_PASSES = 15 +type JsonObject = dict[str, object] +type JsonArray = list[object] + OperationKind = Literal["set-float", "set-choice", "set-bool"] OperationMode = Literal["delta", "set"] RefinementMode = Literal["single-turn", "multi-turn"] @@ -544,7 +547,7 @@ def _build_refinement_status( ) -def parse_request_ids(payload: Any) -> tuple[str, dict[str, str]]: +def parse_request_ids(payload: object) -> tuple[str, dict[str, str]]: if not isinstance(payload, dict): return "", { "appSessionId": "",