Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ version = "0.1.0"
description = "Codex App auth-split proxy for signing in with ChatGPT while using third-party OpenAI-compatible APIs."
readme = "README.md"
requires-python = ">=3.11"
dependencies = []
dependencies = [
"zstandard>=0.25.0",
]
license = { text = "MIT" }
keywords = [
"agent-skills",
Expand Down
104 changes: 102 additions & 2 deletions src/codex_fast_proxy/proxy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import argparse
import gzip
import hashlib
import http.client
import json
import os
import signal
import shutil
import ssl
import subprocess
import sys
import threading
import time
Expand All @@ -17,6 +20,11 @@
from typing import Any, Callable, Iterable
from urllib.parse import urlsplit

try:
import zstandard
except ImportError: # pragma: no cover - optional local dependency
zstandard = None

from . import __version__
from .auth import resolve_env
from .dashboard import DASHBOARD_PATH, render_dashboard
Expand Down Expand Up @@ -156,13 +164,52 @@ def copy_response_headers(headers: Iterable[tuple[str, str]], chunked: bool) ->
return copied


def decompress_zstd(body: bytes) -> bytes:
    """Decompress a zstd payload, preferring the Python binding over the CLI.

    Order of attempts:

    1. If the ``zstandard`` module was imported, read the whole payload
       through ``ZstdDecompressor().stream_reader``.
    2. Otherwise, or if step 1 raised, pipe the payload through a ``zstd``
       executable found on ``PATH`` (``zstd -dc``).

    Returns:
        The decompressed bytes. When neither decoder is available the
        input is returned unchanged.

    Raises:
        Exception: the error from the Python binding when it failed and the
            CLI is missing or also failed (the CLI error is attached as the
            cause in the latter case); otherwise whatever ``subprocess.run``
            raised.
    """
    library_failure: Exception | None = None

    if zstandard is not None:
        try:
            with zstandard.ZstdDecompressor().stream_reader(body) as stream:
                return stream.read()
        except Exception as error:
            # Remember the failure; the CLI fallback below may still succeed.
            library_failure = error

    cli_path = shutil.which("zstd")
    if not cli_path:
        # No fallback decoder: surface the binding's error if there was one,
        # else hand the payload back untouched.
        if library_failure is None:
            return body
        raise library_failure

    try:
        completed = subprocess.run(
            [cli_path, "-dc"],
            input=body,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except Exception as cli_failure:
        if library_failure is None:
            raise
        # Both decoders failed: report the binding's error, chained to the CLI's.
        raise library_failure from cli_failure
    return completed.stdout


def compress_zstd(body: bytes) -> bytes:
    """Compress ``body`` with zstd when the ``zstandard`` module is available.

    Returns the input unchanged if the module could not be imported, so
    callers that need a real zstd frame must check ``zstandard`` themselves.
    """
    if zstandard is not None:
        return zstandard.ZstdCompressor().compress(body)
    return body


def service_tier_patch(
method: str,
raw_path: str,
body: bytes,
content_type: str,
service_tier: str,
service_tier_policy: str = "inject_missing",
content_encoding: str = "",
) -> tuple[bytes, dict[str, Any]]:
event = {
"eligible": False,
Expand All @@ -171,6 +218,9 @@ def service_tier_patch(
"service_tier_after": None,
"stream": None,
"json_error": None,
"request_body_len": len(body),
"request_body_magic": body[:4].hex() if body else "",
"request_content_encoding": content_encoding,
}

if method.upper() != "POST" or normalized_path(raw_path) != RESPONSES_PATH:
Expand All @@ -181,7 +231,14 @@ def service_tier_patch(
return body, event

try:
payload = json.loads(body.decode("utf-8"))
normalized_encoding = content_encoding.lower().strip()
if normalized_encoding == "gzip":
decoded_body = gzip.decompress(body)
elif normalized_encoding == "zstd" and zstandard is not None:
decoded_body = decompress_zstd(body)
else:
decoded_body = body
payload = json.loads(decoded_body.decode("utf-8"))
except Exception as exc:
event["json_error"] = type(exc).__name__
return body, event
Expand All @@ -200,7 +257,13 @@ def service_tier_patch(
event["service_tier_after"] = payload.get("service_tier", "<absent>")
if not event["injected"]:
return body, event
return compact_json(payload).encode("utf-8"), event

patched_body = compact_json(payload).encode("utf-8")
if normalized_encoding == "gzip":
return gzip.compress(patched_body), event
if normalized_encoding == "zstd" and zstandard is not None:
return compress_zstd(patched_body), event
return patched_body, event


def write_chunk(writer: Any, data: bytes) -> None:
Expand Down Expand Up @@ -240,6 +303,35 @@ def stream_response_body(
writer.flush()


# Byte values allowed in a chunk-size field (RFC 9112 chunk-size = 1*HEXDIG).
_CHUNK_SIZE_HEX_DIGITS = frozenset(b"0123456789abcdefABCDEF")


def read_chunked_request_body(reader: Any) -> bytes:
    """Read an HTTP/1.1 chunked request body and return the joined payload.

    Args:
        reader: Binary file-like object positioned at the first chunk-size
            line; must provide ``readline()`` and ``read(n)``.

    Returns:
        The concatenation of all chunk payloads. Chunk extensions and
        trailer fields are consumed and discarded.

    Raises:
        ValueError: ``"incomplete_chunked_request_body"`` when the stream
            ends before a size line or inside a chunk;
            ``"invalid_chunked_request_body"`` when a chunk size is not a
            plain hexadecimal number or a chunk is not followed by a line
            terminator.
    """
    chunks: list[bytes] = []
    while True:
        size_line = reader.readline()
        if not size_line:
            raise ValueError("incomplete_chunked_request_body")

        # Anything after ";" is a chunk extension, which is ignored.
        size_text = size_line.split(b";", 1)[0].strip()
        # int(x, 16) on its own also accepts "-5", "+5", "0x5" and "1_0".
        # A negative size would turn reader.read() into read-until-EOF, and
        # lenient sizes can be parsed differently by the upstream server,
        # so only bare hex digits are allowed.
        if not size_text or not _CHUNK_SIZE_HEX_DIGITS.issuperset(size_text):
            raise ValueError("invalid_chunked_request_body")
        size = int(size_text, 16)

        if size == 0:
            # Last chunk: skip trailer fields up to the blank line. EOF
            # here is tolerated and treated as the end of the body.
            while reader.readline() not in (b"", b"\n", b"\r\n"):
                pass
            return b"".join(chunks)

        chunk = reader.read(size)
        if len(chunk) != size:
            raise ValueError("incomplete_chunked_request_body")
        chunks.append(chunk)

        # Every chunk payload must be followed by its own line terminator.
        if reader.readline() not in (b"\n", b"\r\n"):
            raise ValueError("invalid_chunked_request_body")


class FastProxyHandler(BaseHTTPRequestHandler):
server_version = "CodexFastProxy/0.1"
protocol_version = "HTTP/1.1"
Expand Down Expand Up @@ -285,6 +377,7 @@ def proxy(self) -> None:
self.headers.get("Content-Type", ""),
self.server.service_tier,
getattr(self.server, "service_tier_effective_policy", "inject_missing"),
self.headers.get("Content-Encoding", ""),
)

upstream_path = upstream_request_path(
Expand Down Expand Up @@ -356,6 +449,10 @@ def respond_dashboard(self) -> None:
self.wfile.flush()

def read_request_body(self) -> bytes:
    """Return the raw request body read from ``self.rfile``.

    A ``Transfer-Encoding`` header containing ``chunked`` (case-insensitive)
    takes precedence and the body is de-chunked; otherwise exactly
    ``Content-Length`` bytes are read. A missing, empty or non-positive
    length yields an empty body. A non-numeric ``Content-Length`` makes
    ``int()`` raise ``ValueError``.
    """
    if "chunked" in self.headers.get("Transfer-Encoding", "").lower():
        return read_chunked_request_body(self.rfile)

    declared_length = self.headers.get("Content-Length", "0") or "0"
    length = int(declared_length)
    if length <= 0:
        return b""
    return self.rfile.read(length)

Expand Down Expand Up @@ -423,6 +520,9 @@ def write_event(
"service_tier_effective_policy": getattr(self.server, "service_tier_effective_policy", "inject_missing"),
"stream": patch_event["stream"],
"json_error": patch_event["json_error"],
"request_body_len": patch_event.get("request_body_len"),
"request_body_magic": patch_event.get("request_body_magic"),
"request_content_encoding": patch_event.get("request_content_encoding"),
"response_content_type": response_content_type,
"error_type": error_type,
}
Expand Down
141 changes: 141 additions & 0 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import gzip
import json
import shutil
import sys
Expand Down Expand Up @@ -27,10 +28,13 @@
copy_request_headers,
dashboard_requested,
parse_args,
read_chunked_request_body,
decompress_zstd,
runtime_details,
service_tier_patch,
stream_response_body,
upstream_request_path,
zstandard,
)


Expand Down Expand Up @@ -136,6 +140,91 @@ def test_preserve_policy_does_not_inject_missing_service_tier(self) -> None:
self.assertEqual(event["service_tier_before"], "<absent>")
self.assertEqual(event["service_tier_after"], "<absent>")

def test_injects_missing_service_tier_in_gzip_json_body(self) -> None:
    """A gzip-encoded JSON body is patched and handed back gzip-encoded."""
    request_json = json.dumps({"model": "gpt-5.4", "stream": True})
    compressed_request = gzip.compress(request_json.encode("utf-8"))

    body, event = service_tier_patch(
        "POST",
        "/v1/responses",
        compressed_request,
        "application/json",
        "priority",
        service_tier_policy="inject_missing",
        content_encoding="gzip",
    )

    # The returned body must still be gzip, now carrying the injected tier.
    decoded = json.loads(gzip.decompress(body))
    self.assertEqual(decoded["service_tier"], "priority")
    self.assertTrue(event["injected"])
    self.assertEqual(event["json_error"], None)

@unittest.skipIf(zstandard is None, "zstandard package is not installed")
def test_injects_missing_service_tier_in_zstd_json_body(self) -> None:
    """A zstd-encoded JSON body is patched and handed back zstd-encoded."""
    request_bytes = json.dumps({"model": "gpt-5.4", "stream": True}).encode("utf-8")
    compressed_request = zstandard.ZstdCompressor().compress(request_bytes)

    body, event = service_tier_patch(
        "POST",
        "/v1/responses",
        compressed_request,
        "application/json",
        "priority",
        service_tier_policy="inject_missing",
        content_encoding="zstd",
    )

    # Decode with the one-shot API, independent of the proxy's own helper.
    decoded = json.loads(zstandard.ZstdDecompressor().decompress(body))
    self.assertEqual(decoded["service_tier"], "priority")
    self.assertTrue(event["injected"])
    self.assertEqual(event["json_error"], None)

@unittest.skipIf(zstandard is None, "zstandard package is not installed")
def test_injects_missing_service_tier_in_zstd_body_without_content_size(self) -> None:
    """Injection works for a zstd frame written without a content-size header."""
    request_bytes = json.dumps({"model": "gpt-5.4", "stream": True}).encode("utf-8")
    # write_content_size=False omits the decompressed size from the frame header.
    compressor = zstandard.ZstdCompressor(write_content_size=False)
    compressed_request = compressor.compress(request_bytes)

    body, event = service_tier_patch(
        "POST",
        "/v1/responses",
        compressed_request,
        "application/json",
        "priority",
        service_tier_policy="inject_missing",
        content_encoding="zstd",
    )

    decoded = json.loads(decompress_zstd(body))
    self.assertEqual(decoded["service_tier"], "priority")
    self.assertTrue(event["injected"])
    self.assertEqual(event["json_error"], None)

@unittest.skipIf(zstandard is None or shutil.which("zstd") is None, "zstd fallback is not available")
def test_injects_missing_service_tier_when_python_zstd_stream_reader_fails(self) -> None:
    """Injection still succeeds via the zstd CLI when the Python decoder raises."""

    class BrokenDecompressor:
        """Stand-in decompressor whose stream_reader always fails."""

        def stream_reader(self, _body: bytes) -> Any:
            raise zstandard.ZstdError("forced stream failure")

    request_bytes = json.dumps({"model": "gpt-5.4", "stream": True}).encode("utf-8")
    compressed_request = zstandard.ZstdCompressor(write_content_size=False).compress(request_bytes)

    real_decompressor = zstandard.ZstdDecompressor
    try:
        # Swap in the failing decoder only for the duration of the patch call.
        zstandard.ZstdDecompressor = BrokenDecompressor
        body, event = service_tier_patch(
            "POST",
            "/v1/responses",
            compressed_request,
            "application/json",
            "priority",
            service_tier_policy="inject_missing",
            content_encoding="zstd",
        )
    finally:
        zstandard.ZstdDecompressor = real_decompressor

    decoded = json.loads(decompress_zstd(body))
    self.assertEqual(decoded["service_tier"], "priority")
    self.assertTrue(event["injected"])
    self.assertEqual(event["json_error"], None)

def test_leaves_non_responses_paths_untouched(self) -> None:
body = b'{"model":"gpt-5.4"}'
patched, event = service_tier_patch("POST", "/v1/chat/completions", body, "application/json", "priority")
Expand Down Expand Up @@ -235,6 +324,58 @@ def test_sse_payload_bytes_are_forwarded_without_event_rewrite(self) -> None:

self.assertEqual(writer.getvalue(), b"event: response.output_text.delta\ndata: {\"x\":1}\n\n")

def test_chunked_request_body_is_dechunked_before_fast_injection(self) -> None:
    """A chunked request is de-chunked, patched, and forwarded with Content-Length."""
    scratch_root = ROOT / ".test_tmp"
    scratch_root.mkdir(exist_ok=True)
    scratch_dir = scratch_root / f"chunked-inject-{uuid.uuid4().hex}"
    scratch_dir.mkdir()
    try:
        log_path = scratch_dir / "fast_proxy.jsonl"
        payload = json.dumps({"model": "gpt-test", "stream": True}).encode("utf-8")
        # One data chunk followed by the terminating zero-size chunk.
        wire_body = b"%X\r\n%s\r\n0\r\n\r\n" % (len(payload), payload)
        upstream = FakeConnection()

        # Build the handler without running __init__, then wire in only
        # the attributes this test assigns.
        handler = FastProxyHandler.__new__(FastProxyHandler)
        handler.command = "POST"
        handler.path = "/v1/responses"
        handler.headers = {
            "Content-Type": "application/json",
            "Transfer-Encoding": "chunked",
        }
        handler.rfile = BytesIO(wire_body)
        handler.server = SimpleNamespace(
            proxy_base="/v1",
            upstream_base_path="/v1",
            upstream_netloc="api.example.test",
            service_tier="priority",
            service_tier_policy="inject_missing",
            service_tier_effective_policy="inject_missing",
            log_path=log_path,
            log_lock=threading.Lock(),
            verbose=False,
            open_connection=lambda: upstream,
        )
        handler.forward_response = lambda _response: None
        handler.respond_bad_gateway = lambda: None

        FastProxyHandler.proxy(handler)
        event = json.loads(log_path.read_text(encoding="utf-8"))
    finally:
        shutil.rmtree(scratch_dir)

    # Upstream must receive plain JSON framed by Content-Length, not chunks.
    sent_payload = json.loads(upstream.body)
    self.assertEqual(sent_payload["service_tier"], "priority")
    self.assertEqual(upstream.headers["Content-Length"], str(len(upstream.body)))
    self.assertNotIn("Transfer-Encoding", upstream.headers)

    # The logged event must record the injection.
    self.assertTrue(event["service_tier_injected"])
    self.assertEqual(event["service_tier_before"], "<absent>")
    self.assertEqual(event["service_tier_after"], "priority")
    self.assertIsNone(event["json_error"])

def test_read_chunked_request_body_handles_extensions_and_trailers(self) -> None:
    """Chunk extensions and trailer fields are dropped from the decoded body."""
    wire_body = BytesIO(b"5;foo=bar\r\nhello\r\n0\r\nx-test: ok\r\n\r\n")

    decoded = read_chunked_request_body(wire_body)

    self.assertEqual(decoded, b"hello")

def test_client_disconnect_during_stream_is_logged_without_bad_gateway(self) -> None:
temp_root = ROOT / ".test_tmp"
temp_root.mkdir(exist_ok=True)
Expand Down