diff --git a/pyproject.toml b/pyproject.toml index 32acaa2..a49fb1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,9 @@ version = "0.1.0" description = "Codex App auth-split proxy for signing in with ChatGPT while using third-party OpenAI-compatible APIs." readme = "README.md" requires-python = ">=3.11" -dependencies = [] +dependencies = [ + "zstandard>=0.25.0", +] license = { text = "MIT" } keywords = [ "agent-skills", diff --git a/src/codex_fast_proxy/proxy.py b/src/codex_fast_proxy/proxy.py index 640a657..47ed998 100644 --- a/src/codex_fast_proxy/proxy.py +++ b/src/codex_fast_proxy/proxy.py @@ -1,12 +1,15 @@ from __future__ import annotations import argparse +import gzip import hashlib import http.client import json import os import signal +import shutil import ssl +import subprocess import sys import threading import time @@ -17,6 +20,11 @@ from typing import Any, Callable, Iterable from urllib.parse import urlsplit +try: + import zstandard +except ImportError: # pragma: no cover - optional local dependency + zstandard = None + from . import __version__ from .auth import resolve_env from .dashboard import DASHBOARD_PATH, render_dashboard @@ -156,6 +164,44 @@ def copy_response_headers(headers: Iterable[tuple[str, str]], chunked: bool) -> return copied +def decompress_zstd(body: bytes) -> bytes: + python_error: Exception | None = None + + if zstandard is not None: + try: + with zstandard.ZstdDecompressor().stream_reader(body) as reader: + return reader.read() + except Exception as exc: + python_error = exc + + zstd_bin = shutil.which("zstd") + if zstd_bin: + try: + result = subprocess.run( + [zstd_bin, "-dc"], + input=body, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + check=True, + ) + return result.stdout + except Exception as exc: + if python_error is not None: + raise python_error from exc + raise + + if python_error is not None: + raise python_error + return body + + +def compress_zstd(body: bytes) -> bytes: + if zstandard is None: + return body + + return zstandard.ZstdCompressor().compress(body) + + def service_tier_patch( method: str, raw_path: str, @@ -163,6 +209,7 @@ def service_tier_patch( content_type: str, service_tier: str, service_tier_policy: str = "inject_missing", + content_encoding: str = "", ) -> tuple[bytes, dict[str, Any]]: event = { "eligible": False, @@ -171,6 +218,9 @@ def service_tier_patch( "service_tier_after": None, "stream": None, "json_error": None, + "request_body_len": len(body), + "request_body_magic": body[:4].hex() if body else "", + "request_content_encoding": content_encoding, } if method.upper() != "POST" or normalized_path(raw_path) != RESPONSES_PATH: @@ -181,7 +231,14 @@ def service_tier_patch( return body, event try: - payload = json.loads(body.decode("utf-8")) + normalized_encoding = content_encoding.lower().strip() + if normalized_encoding == "gzip": + decoded_body = gzip.decompress(body) + elif normalized_encoding == "zstd" and zstandard is not None: + decoded_body = decompress_zstd(body) + else: + decoded_body = body + payload = json.loads(decoded_body.decode("utf-8")) except Exception as exc: event["json_error"] = type(exc).__name__ return body, event @@ -200,7 +257,13 @@ def service_tier_patch( event["service_tier_after"] = payload.get("service_tier", "") if not event["injected"]: return body, event - return compact_json(payload).encode("utf-8"), event + + patched_body = compact_json(payload).encode("utf-8") + if normalized_encoding == "gzip": + return gzip.compress(patched_body), event + if normalized_encoding == "zstd" and zstandard is not None: + return compress_zstd(patched_body), event + return patched_body, event def write_chunk(writer: Any, data: bytes) -> None: @@ -240,6 +303,35 @@ def stream_response_body( writer.flush() +def read_chunked_request_body(reader: Any) -> bytes: + chunks: list[bytes] = [] + while True: + size_line = reader.readline() + if not size_line: + raise ValueError("incomplete_chunked_request_body") + + size_text = size_line.split(b";", 1)[0].strip() + try: + size = int(size_text, 16) + except ValueError as exc: + raise ValueError("invalid_chunked_request_body") from exc + + if size == 0: + while True: + trailer_line = reader.readline() + if trailer_line in {b"", b"\n", b"\r\n"}: + return b"".join(chunks) + + chunk = reader.read(size) + if len(chunk) != size: + raise ValueError("incomplete_chunked_request_body") + chunks.append(chunk) + + line_end = reader.readline() + if line_end not in {b"\n", b"\r\n"}: + raise ValueError("invalid_chunked_request_body") + + class FastProxyHandler(BaseHTTPRequestHandler): server_version = "CodexFastProxy/0.1" protocol_version = "HTTP/1.1" @@ -285,6 +377,7 @@ def proxy(self) -> None: self.headers.get("Content-Type", ""), self.server.service_tier, getattr(self.server, "service_tier_effective_policy", "inject_missing"), + self.headers.get("Content-Encoding", ""), ) upstream_path = upstream_request_path( @@ -356,6 +449,10 @@ def respond_dashboard(self) -> None: self.wfile.flush() def read_request_body(self) -> bytes: + transfer_encoding = self.headers.get("Transfer-Encoding", "") + if "chunked" in transfer_encoding.lower(): + return read_chunked_request_body(self.rfile) + length = int(self.headers.get("Content-Length", "0") or "0") return self.rfile.read(length) if length > 0 else b"" @@ -423,6 +520,9 @@ def write_event( "service_tier_effective_policy": getattr(self.server, "service_tier_effective_policy", "inject_missing"), "stream": patch_event["stream"], "json_error": patch_event["json_error"], + "request_body_len": patch_event.get("request_body_len"), + "request_body_magic": patch_event.get("request_body_magic"), + "request_content_encoding": patch_event.get("request_content_encoding"), "response_content_type": response_content_type, "error_type": error_type, } diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 0f304a7..ce9ae64 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -1,5 +1,6 @@ from __future__ import annotations +import gzip import json import shutil import sys @@ -27,10 +28,13 @@ copy_request_headers, dashboard_requested, parse_args, + read_chunked_request_body, + decompress_zstd, runtime_details, service_tier_patch, stream_response_body, upstream_request_path, + zstandard, ) @@ -136,6 +140,91 @@ def test_preserve_policy_does_not_inject_missing_service_tier(self) -> None: self.assertEqual(event["service_tier_before"], "") self.assertEqual(event["service_tier_after"], "") + def test_injects_missing_service_tier_in_gzip_json_body(self) -> None: + payload = {"model": "gpt-5.4", "stream": True} + raw_body = gzip.compress(json.dumps(payload).encode("utf-8")) + body, event = service_tier_patch( + "POST", + "/v1/responses", + raw_body, + "application/json", + "priority", + "inject_missing", + "gzip", + ) + + patched = json.loads(gzip.decompress(body)) + self.assertEqual(patched["service_tier"], "priority") + self.assertTrue(event["injected"]) + self.assertEqual(event["json_error"], None) + + @unittest.skipIf(zstandard is None, "zstandard package is not installed") + def test_injects_missing_service_tier_in_zstd_json_body(self) -> None: + payload = {"model": "gpt-5.4", "stream": True} + raw_body = zstandard.ZstdCompressor().compress(json.dumps(payload).encode("utf-8")) + body, event = service_tier_patch( + "POST", + "/v1/responses", + raw_body, + "application/json", + "priority", + "inject_missing", + "zstd", + ) + + patched = json.loads(zstandard.ZstdDecompressor().decompress(body)) + self.assertEqual(patched["service_tier"], "priority") + self.assertTrue(event["injected"]) + self.assertEqual(event["json_error"], None) + + @unittest.skipIf(zstandard is None, "zstandard package is not installed") + def test_injects_missing_service_tier_in_zstd_body_without_content_size(self) -> None: + payload = {"model": "gpt-5.4", "stream": True} + raw_body = zstandard.ZstdCompressor(write_content_size=False).compress(json.dumps(payload).encode("utf-8")) + body, event = service_tier_patch( + "POST", + "/v1/responses", + raw_body, + "application/json", + "priority", + "inject_missing", + "zstd", + ) + + patched = json.loads(decompress_zstd(body)) + self.assertEqual(patched["service_tier"], "priority") + self.assertTrue(event["injected"]) + self.assertEqual(event["json_error"], None) + + @unittest.skipIf(zstandard is None or shutil.which("zstd") is None, "zstd fallback is not available") + def test_injects_missing_service_tier_when_python_zstd_stream_reader_fails(self) -> None: + payload = {"model": "gpt-5.4", "stream": True} + raw_body = zstandard.ZstdCompressor(write_content_size=False).compress(json.dumps(payload).encode("utf-8")) + original = zstandard.ZstdDecompressor + + class BrokenDecompressor: + def stream_reader(self, _body: bytes) -> Any: + raise zstandard.ZstdError("forced stream failure") + + try: + zstandard.ZstdDecompressor = BrokenDecompressor + body, event = service_tier_patch( + "POST", + "/v1/responses", + raw_body, + "application/json", + "priority", + "inject_missing", + "zstd", + ) + finally: + zstandard.ZstdDecompressor = original + + patched = json.loads(decompress_zstd(body)) + self.assertEqual(patched["service_tier"], "priority") + self.assertTrue(event["injected"]) + self.assertEqual(event["json_error"], None) + def test_leaves_non_responses_paths_untouched(self) -> None: body = b'{"model":"gpt-5.4"}' patched, event = service_tier_patch("POST", "/v1/chat/completions", body, "application/json", "priority") @@ -235,6 +324,58 @@ def test_sse_payload_bytes_are_forwarded_without_event_rewrite(self) -> None: self.assertEqual(writer.getvalue(), b"event: response.output_text.delta\ndata: {\"x\":1}\n\n") + def test_chunked_request_body_is_dechunked_before_fast_injection(self) -> None: + temp_root = ROOT / ".test_tmp" + temp_root.mkdir(exist_ok=True) + temp_dir = temp_root / f"chunked-inject-{uuid.uuid4().hex}" + temp_dir.mkdir() + try: + log_path = temp_dir / "fast_proxy.jsonl" + raw_body = json.dumps({"model": "gpt-test", "stream": True}).encode("utf-8") + chunked_body = b"%X\r\n%s\r\n0\r\n\r\n" % (len(raw_body), raw_body) + connection = FakeConnection() + handler = FastProxyHandler.__new__(FastProxyHandler) + handler.command = "POST" + handler.path = "/v1/responses" + handler.headers = { + "Content-Type": "application/json", + "Transfer-Encoding": "chunked", + } + handler.rfile = BytesIO(chunked_body) + handler.server = SimpleNamespace( + proxy_base="/v1", + upstream_base_path="/v1", + upstream_netloc="api.example.test", + service_tier="priority", + service_tier_policy="inject_missing", + service_tier_effective_policy="inject_missing", + log_path=log_path, + log_lock=threading.Lock(), + verbose=False, + open_connection=lambda: connection, + ) + handler.forward_response = lambda _response: None + handler.respond_bad_gateway = lambda: None + + FastProxyHandler.proxy(handler) + event = json.loads(log_path.read_text(encoding="utf-8")) + finally: + shutil.rmtree(temp_dir) + + forwarded = json.loads(connection.body) + self.assertEqual(forwarded["service_tier"], "priority") + self.assertEqual(connection.headers["Content-Length"], str(len(connection.body))) + self.assertNotIn("Transfer-Encoding", connection.headers) + self.assertTrue(event["service_tier_injected"]) + self.assertEqual(event["service_tier_before"], "") + self.assertEqual(event["service_tier_after"], "priority") + self.assertIsNone(event["json_error"]) + + def test_read_chunked_request_body_handles_extensions_and_trailers(self) -> None: + body = read_chunked_request_body(BytesIO(b"5;foo=bar\r\nhello\r\n0\r\nx-test: ok\r\n\r\n")) + + self.assertEqual(body, b"hello") + def test_client_disconnect_during_stream_is_logged_without_bad_gateway(self) -> None: temp_root = ROOT / ".test_tmp" temp_root.mkdir(exist_ok=True)