Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/knowledge_service/clients/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ def __init__(
model: str,
api_key: str,
registry: DomainRegistry | None = None,
max_concurrent: int = 4,
) -> None:
super().__init__(base_url, model, api_key, read_timeout=600.0)
self._registry = registry
# Cap concurrent LLM calls so an aegis ingestion burst doesn't queue
# past the read timeout inside Ollama. Acquired around the POST itself
# (not the whole retry loop) so backoff sleep doesn't hold a slot.
self._sem = asyncio.Semaphore(max_concurrent)
self._prompt_builder = None
if registry is not None:
from knowledge_service.clients.prompt_builder import PromptBuilder # noqa: PLC0415
Expand All @@ -111,13 +116,14 @@ async def _post_chat(self, prompt: str) -> str | None:

for attempt in range(_EXTRACT_MAX_RETRIES + 1):
try:
response = await self._client.post(
"/v1/chat/completions",
json={
"model": self._model,
"messages": [{"role": "user", "content": prompt}],
},
)
async with self._sem:
response = await self._client.post(
"/v1/chat/completions",
json={
"model": self._model,
"messages": [{"role": "user", "content": prompt}],
},
)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
except httpx.HTTPStatusError as exc:
Expand Down
6 changes: 6 additions & 0 deletions src/knowledge_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class Settings(BaseSettings):
max_chunks: int = 50
embed_batch_size: int = 20
entity_cache_max_size: int = 1000
# Cap concurrent extraction LLM calls. Without this, a burst of N ingestion
# jobs (typical aegis daily arxiv pull is 30–50) fires N parallel requests
# at qwen3, which serves ~2–4 at a time on asif; the tail queues inside
# Ollama past the 600s read timeout and the whole batch falls into a
# retry-cascade. Pick 4 to match observed parallelism.
extraction_max_concurrent: int = 4

# Ingestion pipeline
spacy_data_dir: str = "/app/data/spacy"
Expand Down
1 change: 1 addition & 0 deletions src/knowledge_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
model=settings.llm_chat_model,
api_key=settings.llm_api_key,
registry=domain_registry,
max_concurrent=settings.extraction_max_concurrent,
)
app.state.extraction_client = extraction_client
app.state.embedding_client = embedding_client
Expand Down
53 changes: 52 additions & 1 deletion tests/test_extraction_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json

import httpx
Expand All @@ -6,11 +7,14 @@
from knowledge_service.clients.llm import ExtractionClient
from knowledge_service.models import EntityInput, TripleInput

# Capture the real asyncio.sleep before any autouse fixture rebinds it; concurrency
# tests need a real yield to observe overlap.
_REAL_SLEEP = asyncio.sleep


@pytest.fixture(autouse=True)
def _skip_retry_backoff(monkeypatch):
"""Monkeypatch asyncio.sleep to a no-op so retry tests don't actually wait."""
import asyncio

async def _nosleep(_seconds):
return None
Expand Down Expand Up @@ -420,3 +424,50 @@ async def test_no_auth_header_when_key_empty(self, httpx_mock):
headers = httpx_mock.get_requests()[0].headers
assert "authorization" not in headers
await client.close()


class TestConcurrencyCap:
"""ExtractionClient must bound concurrent LLM calls.

Regression for the prod cascade: an aegis batch of 30+ ingestion jobs
all hit qwen3 simultaneously, queued inside Ollama, and timed out at
the 600s read boundary together. The semaphore moves the queue from
qwen3 back into KS, where the read timeout doesn't apply.
"""

async def test_inflight_never_exceeds_cap(self, monkeypatch):
cap = 4
burst = 12 # ≥ 3× the cap so violations are obvious if uncapped

inflight = 0
max_inflight = 0

async def trace_post(*_args, **_kwargs):
nonlocal inflight, max_inflight
inflight += 1
max_inflight = max(max_inflight, inflight)
try:
# Real sleep to keep slots occupied long enough for overlap.
await _REAL_SLEEP(0.02)
finally:
inflight -= 1
return httpx.Response(
200,
json=_make_combined_response(entities=[], relations=[]),
request=httpx.Request("POST", _CHAT_URL),
)

client = ExtractionClient(
base_url=_BASE, model="qwen3:14b", api_key=_KEY, max_concurrent=cap
)
monkeypatch.setattr(client._client, "post", trace_post)

await asyncio.gather(*[client._post_chat("prompt") for _ in range(burst)])

assert max_inflight <= cap, f"max_inflight={max_inflight} exceeded cap={cap}"
await client.close()

async def test_default_cap_is_set(self):
client = ExtractionClient(base_url=_BASE, model="qwen3:14b", api_key=_KEY)
assert client._sem._value >= 1
await client.close()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading