11 changes: 9 additions & 2 deletions headroom/image/onnx_router.py
@@ -20,6 +20,7 @@
 import numpy as np
 
 from headroom.image.trained_router import ImageSignals, RouteDecision, Technique
+from headroom.onnx_runtime import create_cpu_session_options
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,9 @@ def _load_classifier(self) -> None:
 
         model_path = hf_hub_download(_TECHNIQUE_ROUTER_REPO, "model_quantized.onnx")
         self._classifier_session = ort.InferenceSession(
-            model_path, providers=["CPUExecutionProvider"]
+            model_path,
+            create_cpu_session_options(ort),
+            providers=["CPUExecutionProvider"],
         )
 
         tokenizer_path = hf_hub_download(_TECHNIQUE_ROUTER_REPO, "tokenizer.json")
@@ -95,7 +98,11 @@ def _load_siglip(self) -> None:
         logger.info("Loading SigLIP ONNX INT8 image encoder...")
 
         model_path = hf_hub_download(_SIGLIP_ENCODER_REPO, "image_encoder_int8.onnx")
-        self._siglip_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
+        self._siglip_session = ort.InferenceSession(
+            model_path,
+            create_cpu_session_options(ort),
+            providers=["CPUExecutionProvider"],
+        )
 
         embeddings_path = hf_hub_download(_SIGLIP_ENCODER_REPO, "text_embeddings.npz")
         loaded = np.load(embeddings_path)
13 changes: 9 additions & 4 deletions headroom/memory/adapters/embedders.py
@@ -19,6 +19,7 @@
 import numpy as np
 
 from headroom.models.config import ML_MODEL_DEFAULTS
+from headroom.onnx_runtime import create_cpu_session_options
 
 if TYPE_CHECKING:
     from sentence_transformers import SentenceTransformer
@@ -311,10 +312,14 @@ def _load_model(self) -> None:
         model_path = hf_hub_download(self.ONNX_REPO, "model.onnx")
         tok_path = hf_hub_download(self.ONNX_REPO, "tokenizer.json")
 
-        # Set thread count to avoid pthread_setaffinity_np errors in Docker containers
-        sess_options = ort.SessionOptions()
-        sess_options.intra_op_num_threads = 1
-        sess_options.inter_op_num_threads = 1
+        # Keep a small thread pool for Docker compatibility and disable ORT's
+        # CPU memory arena/pattern caches so long-running proxy workers do not
+        # retain large anonymous heaps after embedding bursts.
+        sess_options = create_cpu_session_options(
+            ort,
+            intra_op_num_threads=1,
+            inter_op_num_threads=1,
+        )
         self._session = ort.InferenceSession(
             model_path, sess_options, providers=["CPUExecutionProvider"]
        )
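
For reference, the `create_cpu_session_options` call above expands to the same `ort.SessionOptions` configuration the replaced code set by hand, plus the two retention toggles. A sketch of the equivalent direct setup (illustration only, not code from this PR):

```python
import onnxruntime as ort

# Equivalent hand-rolled setup: pin the thread pools (the Docker
# pthread_setaffinity_np workaround) and disable ORT's retention caches.
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 1
sess_options.inter_op_num_threads = 1
sess_options.enable_cpu_mem_arena = False  # stop caching freed tensor buffers in ORT's arena
sess_options.enable_mem_pattern = False    # stop pre-planning and retaining allocation patterns
```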
51 changes: 51 additions & 0 deletions headroom/onnx_runtime.py
@@ -0,0 +1,51 @@
"""ONNX Runtime helpers for long-running Headroom processes."""

from __future__ import annotations

import ctypes
import sys
from typing import Any


def create_cpu_session_options(
ort: Any,
*,
intra_op_num_threads: int | None = None,
inter_op_num_threads: int | None = None,
) -> Any:
"""Create CPU-oriented ONNX Runtime session options.

Headroom runs as a long-lived proxy process, so we bias toward predictable
memory usage over peak ONNX throughput. Disabling ORT's CPU arena and memory
pattern caches reduces retained anonymous RSS after variable-size inference
workloads, which is especially important on small VMs.
"""
sess_options = ort.SessionOptions()

if intra_op_num_threads is not None:
sess_options.intra_op_num_threads = intra_op_num_threads
if inter_op_num_threads is not None:
sess_options.inter_op_num_threads = inter_op_num_threads

if hasattr(sess_options, "enable_cpu_mem_arena"):
sess_options.enable_cpu_mem_arena = False
if hasattr(sess_options, "enable_mem_pattern"):
sess_options.enable_mem_pattern = False

return sess_options


def trim_process_heap() -> bool:
"""Ask glibc to return unused heap pages to the OS when available."""
if not sys.platform.startswith("linux"):
return False

try:
libc = ctypes.CDLL("libc.so.6")
except OSError:
return False

try:
return bool(libc.malloc_trim(0))
except Exception:
return False
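
Taken together, the two helpers give the intended session lifecycle in a long-lived worker. A minimal usage sketch, assuming `onnxruntime` is installed and using a hypothetical `model.onnx` path that is not part of this PR:

```python
import onnxruntime as ort

from headroom.onnx_runtime import create_cpu_session_options, trim_process_heap

# Build options with the CPU arena and memory-pattern caches disabled,
# pinned to one intra-op thread for container friendliness.
options = create_cpu_session_options(ort, intra_op_num_threads=1)
session = ort.InferenceSession(
    "model.onnx",  # hypothetical path, for illustration only
    options,
    providers=["CPUExecutionProvider"],
)

# ... run inference bursts ...

# Once the session is dropped, ask glibc to return freed pages to the OS.
del session
trim_process_heap()
```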
25 changes: 17 additions & 8 deletions headroom/transforms/kompress_compressor.py
@@ -14,12 +14,14 @@
 
 from __future__ import annotations
 
+import gc
 import logging
 import threading
 from dataclasses import dataclass
 from typing import Any
 
 from ..config import TransformResult
+from ..onnx_runtime import create_cpu_session_options, trim_process_heap
 from ..tokenizer import Tokenizer
 from .base import Transform
 
@@ -174,7 +176,11 @@ def _load_kompress_onnx(model_id: str) -> tuple[Any, Any, str]:
     logger.info("Downloading Kompress ONNX model from %s ...", model_id)
     onnx_path = hf_hub_download(model_id, "onnx/kompress-int8.onnx")
 
-    session = ort.InferenceSession(onnx_path)
+    session = ort.InferenceSession(
+        onnx_path,
+        create_cpu_session_options(ort),
+        providers=["CPUExecutionProvider"],
+    )
     model = _OnnxModel(session)
     tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
 
@@ -264,14 +270,17 @@ def unload_kompress_model(model_id: str | None = None) -> bool:
     else:
         return False
 
-    try:
-        import torch
+    try:
+        import torch
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except ImportError:
-        pass
-    return True
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+    except ImportError:
+        pass
+
+    gc.collect()
+    trim_process_heap()
+    return True
 
 
 # ── Compressor ────────────────────────────────────────────────────────
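
The ordering in `unload_kompress_model` matters: `gc.collect()` first releases the Python-side references so the allocator actually frees the underlying pages, and only then can `malloc_trim` hand them back to the OS. A rough, Linux-only way to observe the effect (a sketch, not part of this PR; numbers will vary):

```python
import gc

from headroom.onnx_runtime import trim_process_heap


def rss_kib() -> int:
    # Read VmRSS from /proc/self/status (Linux only), in KiB.
    with open("/proc/self/status") as f:
        for line in f:
            if line.startswith("VmRSS:"):
                return int(line.split()[1])
    return 0


before = rss_kib()
gc.collect()          # drop Python-level references first ...
trim_process_heap()   # ... then return freed heap pages to the OS
print(f"RSS: {before} KiB -> {rss_kib()} KiB")
```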
43 changes: 43 additions & 0 deletions tests/test_onnx_runtime.py
@@ -0,0 +1,43 @@
+from headroom.onnx_runtime import create_cpu_session_options
+
+
+class _FakeSessionOptions:
+    def __init__(self):
+        self.intra_op_num_threads = None
+        self.inter_op_num_threads = None
+        self.enable_cpu_mem_arena = True
+        self.enable_mem_pattern = True
+
+
+class _FakeOrt:
+    SessionOptions = _FakeSessionOptions
+
+
+class _FakeSessionOptionsWithoutToggles:
+    def __init__(self):
+        self.intra_op_num_threads = None
+        self.inter_op_num_threads = None
+
+
+class _FakeOrtWithoutToggles:
+    SessionOptions = _FakeSessionOptionsWithoutToggles
+
+
+def test_create_cpu_session_options_disables_retention_features():
+    options = create_cpu_session_options(
+        _FakeOrt,
+        intra_op_num_threads=1,
+        inter_op_num_threads=2,
+    )
+
+    assert options.intra_op_num_threads == 1
+    assert options.inter_op_num_threads == 2
+    assert options.enable_cpu_mem_arena is False
+    assert options.enable_mem_pattern is False
+
+
+def test_create_cpu_session_options_handles_older_session_options():
+    options = create_cpu_session_options(_FakeOrtWithoutToggles)
+
+    assert options.intra_op_num_threads is None
+    assert options.inter_op_num_threads is None
4 changes: 1 addition & 3 deletions tests/test_proxy_codex_route_aliases.py
@@ -118,9 +118,7 @@ async def aclose(self) -> None:
         ),
     ],
 )
-def test_codex_responses_subpath_passthrough_derives_chatgpt_routing_from_jwt(
-    path, expected_url
-):
+def test_codex_responses_subpath_passthrough_derives_chatgpt_routing_from_jwt(path, expected_url):
     class FakeAsyncClient:
         def __init__(self) -> None:
             self.calls: list[tuple[str, str, dict[str, str]]] = []
22 changes: 11 additions & 11 deletions tests/test_proxy_google_cloudcode_route_aliases.py
@@ -98,9 +98,7 @@ async def fake_stream(self, url, _headers, _body, provider, model, *_args, **_kw
     monkeypatch.setattr(HeadroomProxy, "_stream_response", fake_stream)
 
     with TestClient(
-        create_app(
-            ProxyConfig(optimize=False, cloudcode_api_url="https://cloudcode-proxy.test/v1")
-        )
+        create_app(ProxyConfig(optimize=False, cloudcode_api_url="https://cloudcode-proxy.test/v1"))
     ) as client:
         response = client.post(
             "/v1/v1internal:streamGenerateContent",
@@ -150,9 +148,7 @@ async def fake_stream(self, url, _headers, _body, provider, model, *_args, **_kw
     monkeypatch.setattr(HeadroomProxy, "_stream_response", fake_stream)
 
     with TestClient(
-        create_app(
-            ProxyConfig(optimize=False, cloudcode_api_url="https://cloudcode-proxy.test")
-        )
+        create_app(ProxyConfig(optimize=False, cloudcode_api_url="https://cloudcode-proxy.test"))
     ) as client:
         response = client.post(
             "/v1internal:streamGenerateContent",
Expand All @@ -175,9 +171,7 @@ async def fake_stream(self, url, _headers, _body, provider, model, *_args, **_kw
monkeypatch.setattr(HeadroomProxy, "_stream_response", fake_stream)

with TestClient(
create_app(
ProxyConfig(optimize=False, cloudcode_api_url="https://cloudcode-proxy.test")
)
create_app(ProxyConfig(optimize=False, cloudcode_api_url="https://cloudcode-proxy.test"))
) as client:
first = client.post(
"/v1internal:streamGenerateContent",
@@ -193,6 +187,12 @@ async def fake_stream(self, url, _headers, _body, provider, model, *_args, **_kw
     )
 
     assert first.status_code == 200
-    assert first.json()["url"] == "https://cloudcode-proxy.test/v1internal:streamGenerateContent?alt=sse"
+    assert (
+        first.json()["url"]
+        == "https://cloudcode-proxy.test/v1internal:streamGenerateContent?alt=sse"
+    )
     assert second.status_code == 200
-    assert second.json()["url"] == "https://cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse"
+    assert (
+        second.json()["url"]
+        == "https://cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse"
+    )
4 changes: 1 addition & 3 deletions tests/test_proxy_streaming_ratelimit_headers.py
@@ -220,9 +220,7 @@ async def test_upstream_http_error_preserves_status_body_and_metrics(self, monke
                 "content-length": "42",
             }
         )
-        mock_response.aread = AsyncMock(
-            return_value=b'{"error":{"message":"capacity exhausted"}}'
-        )
+        mock_response.aread = AsyncMock(return_value=b'{"error":{"message":"capacity exhausted"}}')
        mock_response.aclose = AsyncMock()
 
         mock_request = MagicMock()