Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import json
import logging
import os
from typing import Any

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
Expand All @@ -24,6 +23,7 @@
ResponseEnvelope,
ResponseSession,
AssistantMessage,
JsonObject,
build_response_from_plan,
parse_request_ids,
)
Expand All @@ -34,7 +34,7 @@

class JsonFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
payload: dict[str, Any] = {
payload: JsonObject = {
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
Expand Down Expand Up @@ -122,7 +122,7 @@ def build_error_payload(
refinement: RefinementStatus | None,
code: str,
message: str,
) -> dict[str, Any]:
) -> JsonObject:
payload = ResponseEnvelope(
requestId=request_id,
session=ResponseSession.model_validate(session),
Expand Down Expand Up @@ -197,7 +197,7 @@ def _log_fulfilled_request(
)


def _encode_sse(event: str, payload: Mapping[str, Any]) -> str:
def _encode_sse(event: str, payload: Mapping[str, object]) -> str:
return f"event: {event}\ndata: {json.dumps(payload, separators=(',', ':'))}\n\n"


Expand Down
6 changes: 3 additions & 3 deletions server/bridge_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Any, Protocol, TypedDict
from typing import Protocol, TypedDict

from shared.protocol import AgentPlan, RequestEnvelope
from shared.protocol import AgentPlan, JsonObject, RequestEnvelope


class RequestProgressPayload(TypedDict):
Expand All @@ -11,7 +11,7 @@ class RequestProgressPayload(TypedDict):
toolCallsUsed: int
maxToolCalls: int
appliedOperationCount: int
operations: list[dict[str, Any]]
operations: list[JsonObject]
message: str
lastToolName: str | None
progressVersion: int
Expand Down
26 changes: 11 additions & 15 deletions server/codex_bridge/apply_batch.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any

from typing import cast
from shared.canonical_plan import CanonicalEditAction
from shared.protocol import JsonObject

from .canonical_binder import bind_canonical_actions
from .models import TurnContext


@dataclass(frozen=True, slots=True)
class PreparedApplyBatch:
normalized_batch: list[dict[str, Any]]
normalized_batch: list[JsonObject]
render_warnings: list[str]


def prepare_apply_batch(
context: TurnContext,
arguments: dict[str, Any],
arguments: JsonObject,
*,
normalize_operation: Callable[
[dict[str, Any], int], tuple[dict[str, Any], str | None]
],
normalize_operation: Callable[[JsonObject, int], tuple[JsonObject, str | None]],
) -> tuple[PreparedApplyBatch | None, str | None]:
raw_operations = arguments.get("operations")
raw_canonical_actions = arguments.get("canonicalActions")
Expand All @@ -48,7 +46,7 @@ def prepare_apply_batch(


def _prepare_canonical_batch(
context: TurnContext, raw_canonical_actions: list[Any]
context: TurnContext, raw_canonical_actions: Sequence[object]
) -> tuple[PreparedApplyBatch | None, str | None]:
canonical_actions: list[CanonicalEditAction] = []
for raw_action in raw_canonical_actions:
Expand Down Expand Up @@ -76,17 +74,15 @@ def _prepare_canonical_batch(

def _prepare_raw_batch(
context: TurnContext,
raw_operations: list[Any],
normalize_operation: Callable[
[dict[str, Any], int], tuple[dict[str, Any], str | None]
],
raw_operations: Sequence[object],
normalize_operation: Callable[[JsonObject, int], tuple[JsonObject, str | None]],
) -> tuple[PreparedApplyBatch | None, str | None]:
normalized_batch: list[dict[str, Any]] = []
normalized_batch: list[JsonObject] = []
for index, raw_operation in enumerate(raw_operations):
if not isinstance(raw_operation, dict):
return None, "Every apply_operations entry must be an object."
normalized_operation, error = normalize_operation(
raw_operation,
cast(JsonObject, raw_operation),
context.next_operation_sequence + index,
)
if error:
Expand Down
42 changes: 30 additions & 12 deletions server/codex_bridge/image_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import base64
import binascii
import io
from typing import Any, Literal
from typing import Literal, TypedDict

from shared.analysis_signals import (
ActiveModuleSignal,
Expand All @@ -12,7 +12,25 @@
RegionSignalSummary,
TonalSignalSummary,
)
from shared.protocol import RequestEnvelope
from shared.protocol import JsonObject, RequestEnvelope


class PreviewPixel(TypedDict):
red: float
green: float
blue: float
luma: float
saturation: float


class PreviewSamples(TypedDict):
width: int
height: int
samples: list[PreviewPixel]
lumas: list[float]
saturations: list[float]
grayscale: list[float]


try:
from PIL import Image
Expand Down Expand Up @@ -41,7 +59,7 @@ def _decode_preview_bytes(request: RequestEnvelope) -> bytes | None:
return None


def _preview_samples(image_bytes: bytes) -> dict[str, Any] | None:
def _preview_samples(image_bytes: bytes) -> PreviewSamples | None:
if Image is None or not image_bytes:
return None

Expand All @@ -60,7 +78,7 @@ def _preview_samples(image_bytes: bytes) -> dict[str, Any] | None:
grayscale: list[float] = []
lumas: list[float] = []
saturations: list[float] = []
samples: list[dict[str, float]] = []
samples: list[PreviewPixel] = []
for red, green, blue in pixels:
luma = (0.2126 * red + 0.7152 * green + 0.0722 * blue) / 255.0
saturation = (max(red, green, blue) - min(red, green, blue)) / 255.0
Expand All @@ -87,7 +105,7 @@ def _preview_samples(image_bytes: bytes) -> dict[str, Any] | None:
}


def _tonal_from_preview(preview_samples: dict[str, Any]) -> TonalSignalSummary:
def _tonal_from_preview(preview_samples: PreviewSamples) -> TonalSignalSummary:
lumas = sorted(float(value) for value in preview_samples["lumas"])
samples = preview_samples["samples"]
total = max(1, len(samples))
Expand Down Expand Up @@ -146,7 +164,7 @@ def _tonal_from_histogram(request: RequestEnvelope) -> TonalSignalSummary | None


def _sharpness_estimate(
preview_samples: dict[str, Any] | None,
preview_samples: PreviewSamples | None,
) -> Literal["unknown", "soft", "normal", "crisp"]:
if preview_samples is None:
return "unknown"
Expand Down Expand Up @@ -190,29 +208,29 @@ def _noise_risk(


def _region_slice(
preview_samples: dict[str, Any],
preview_samples: PreviewSamples,
*,
x_start: float,
x_end: float,
y_start: float,
y_end: float,
) -> list[dict[str, float]]:
) -> list[PreviewPixel]:
width = int(preview_samples["width"])
height = int(preview_samples["height"])
samples = preview_samples["samples"]
left = max(0, min(width - 1, int(width * x_start)))
right = max(left + 1, min(width, int(width * x_end)))
top = max(0, min(height - 1, int(height * y_start)))
bottom = max(top + 1, min(height, int(height * y_end)))
region: list[dict[str, float]] = []
region: list[PreviewPixel] = []
for y in range(top, bottom):
start = y * width + left
end = y * width + right
region.extend(samples[start:end])
return region


def _region_stats(region: list[dict[str, float]]) -> tuple[float, float]:
def _region_stats(region: list[PreviewPixel]) -> tuple[float, float]:
if not region:
return 0.0, 0.0
mean_luma = sum(sample["luma"] for sample in region) / len(region)
Expand All @@ -221,7 +239,7 @@ def _region_stats(region: list[dict[str, float]]) -> tuple[float, float]:


def _region_summaries(
preview_samples: dict[str, Any] | None,
preview_samples: PreviewSamples | None,
) -> list[RegionSignalSummary]:
if preview_samples is None:
return []
Expand Down Expand Up @@ -303,7 +321,7 @@ def _active_modules(request: RequestEnvelope) -> tuple[int, list[ActiveModuleSig
return len(active_history), signals


def build_image_analysis_signals(request: RequestEnvelope) -> dict[str, Any]:
def build_image_analysis_signals(request: RequestEnvelope) -> JsonObject:
preview_samples = _preview_samples(_decode_preview_bytes(request) or b"")
tonal = _tonal_from_preview(preview_samples) if preview_samples else None
if tonal is None:
Expand Down
16 changes: 8 additions & 8 deletions server/codex_bridge/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import threading
from dataclasses import dataclass, field
from typing import Any, TypedDict
from typing import TypedDict

from shared.protocol import AgentPlan, RequestEnvelope
from shared.protocol import AgentPlan, JsonObject, RequestEnvelope


@dataclass(frozen=True, slots=True)
Expand Down Expand Up @@ -60,19 +60,19 @@ class TurnContext:
current_preview_bytes: bytes
preview_mime_type: str
base_image_revision_id: str
state_payload: dict[str, Any]
setting_by_id: dict[str, dict[str, Any]]
state_payload: JsonObject
setting_by_id: dict[str, JsonObject]
base_float_setting_numbers: dict[str, float]
live_run_enabled: bool
max_tool_calls: int
tool_calls_used: int = 0
consecutive_read_only_tool_calls: int = 0
applied_operations: list[dict[str, Any]] = field(default_factory=list)
applied_operations: list[JsonObject] = field(default_factory=list)
next_operation_sequence: int = 1
render_event: threading.Event = field(default_factory=threading.Event)
rendered_preview_bytes: bytes | None = None
requires_render_callback: bool = False
last_applied_batch: list[dict[str, Any]] = field(default_factory=list)
last_applied_batch: list[JsonObject] = field(default_factory=list)
last_applied_summary: str | None = None
last_verifier_status: str | None = None
last_verifier_summary: str | None = None
Expand All @@ -85,7 +85,7 @@ class TurnRunState(TypedDict):
final_message: str | None
turn_error: str | None
completed: bool
token_usage_last: dict[str, Any] | None
token_usage_total: dict[str, Any] | None
token_usage_last: JsonObject | None
token_usage_total: JsonObject | None
last_activity_at: float
last_activity_method: str | None
Loading