From 123b4f41b8661ad3d72bf5070067c117bdfbd272 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Thu, 4 Jun 2026 10:23:27 +0000 Subject: [PATCH 01/12] =?UTF-8?q?feat:=20SHEK-16=20=E2=80=94=20in-cluster?= =?UTF-8?q?=20Kubernetes=20auto-discovery=20from=20ConfigMap?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New module shekel/integrations/kubernetes.py: - is_k8s_environment(): detects KUBERNETES_SERVICE_HOST + SHEKEL_BUDGET_NAME - _fetch_configmap(): loads shekel-budget-{name} from the pod's namespace via kubernetes.client.CoreV1Api; soft-imports kubernetes (no crash if absent) - apply_k8s_config(): applies ConfigMap values to Budget fields where still None (priority: explicit kwarg > AGENT_BUDGET_USD env var > ConfigMap) - KubernetesPoller: daemon thread that polls paused key every SHEKEL_POLL_INTERVAL_SECONDS (default 10s); sets _paused_externally Budget._record_spend(): raises BudgetExceededError immediately when _paused_externally is True (before spend accumulation). Budget.__exit__ / __aexit__: stop the poller thread on context exit. pyproject.toml: add [k8s] extra (kubernetes>=28.0); add to [all]; add kubernetes mypy override. 36 tests; 100% coverage on kubernetes.py. Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 7 +- shekel/_budget.py | 31 ++ shekel/integrations/kubernetes.py | 170 +++++++ tests/test_kubernetes_integration.py | 683 +++++++++++++++++++++++++++ 4 files changed, 890 insertions(+), 1 deletion(-) create mode 100644 shekel/integrations/kubernetes.py create mode 100644 tests/test_kubernetes_integration.py diff --git a/pyproject.toml b/pyproject.toml index 8433387..e571de8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,8 @@ huggingface = ["huggingface-hub>=0.20.0"] otel = ["opentelemetry-api>=1.0.0"] redis = ["redis>=4.0.0"] autogen = ["pyautogen>=0.2.0,<0.3.0"] -all = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "litellm>=1.0.0", "google-genai>=1.0.0", "huggingface-hub>=0.20.0", "opentelemetry-api>=1.0.0", "redis>=4.0.0", "pyautogen>=0.2.0,<0.3.0"] +k8s = ["kubernetes>=28.0"] +all = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "litellm>=1.0.0", "google-genai>=1.0.0", "huggingface-hub>=0.20.0", "opentelemetry-api>=1.0.0", "redis>=4.0.0", "pyautogen>=0.2.0,<0.3.0", "kubernetes>=28.0"] all-models = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "tokencost>=0.1.0"] cli = ["click>=8.0.0"] dev = [ @@ -215,6 +216,10 @@ ignore_missing_imports = true module = ["agents", "agents.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["kubernetes", "kubernetes.*"] +ignore_missing_imports = true + [[tool.mypy.overrides]] module = ["_pytest", "_pytest.*"] follow_imports = "skip" diff --git a/shekel/_budget.py b/shekel/_budget.py index 9b181fc..17dd7ee 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -300,6 +300,23 @@ def __init__( self._chain_budgets: dict[str, ComponentBudget] = {} self._runtime: Any = None + # --- Kubernetes auto-discovery (SHEK-16) --- + self._paused_externally: bool = False + self._k8s_poller: Any = None + self._per_pod_budget: Any = None + self._k8s_redis_backend: Any = None + self._k8s_redis_name: str | None = None + self._k8s_flush_every_usd: float | None = None + self._k8s_flush_every_seconds: float | None = None + self._k8s_scope_mode: str | None = None + self._k8s_scope_group_by: str | None = None + try: + from shekel.integrations.kubernetes import apply_k8s_config # noqa: PLC0415 + + apply_k8s_config(self) + except Exception: + pass + # ------------------------------------------------------------------ # Internal state reset # ------------------------------------------------------------------ @@ -470,6 +487,8 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: if self._runtime is not None: self._runtime.release() self._runtime = None + if self._k8s_poller is not None: + self._k8s_poller.stop() _patch.remove_patches() # returning None (not False) — never suppress exceptions @@ -610,6 +629,8 @@ async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> if self._runtime is not None: self._runtime.release() self._runtime = None + if self._k8s_poller is not None: + self._k8s_poller.stop() _patch.remove_patches() # ------------------------------------------------------------------ @@ -634,6 +655,16 @@ def reset(self) -> None: # ------------------------------------------------------------------ def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None: + if self._paused_externally: + from shekel.exceptions import BudgetExceededError # noqa: PLC0415 + + raise BudgetExceededError( + spent=self._spent, + limit=self.max_usd or 0.0, + model=model, + tokens=tokens, + ) + # Parent locking: cannot record spend while a child budget is active if self.active_child is not None: raise RuntimeError( diff --git a/shekel/integrations/kubernetes.py b/shekel/integrations/kubernetes.py new file mode 100644 index 0000000..ceb0cc5 --- /dev/null +++ b/shekel/integrations/kubernetes.py @@ -0,0 +1,170 @@ +"""Kubernetes in-cluster auto-discovery for shekel budget configuration (SHEK-16). + +When both KUBERNETES_SERVICE_HOST and SHEKEL_BUDGET_NAME are set, Budget.__init__ +loads its configuration from the ConfigMap shekel-budget-{name} in the pod's +namespace. A background daemon thread polls the ConfigMap's paused key at +SHEKEL_POLL_INTERVAL_SECONDS (default 10) to implement the kill-switch. + +Config priority (lowest → highest): ConfigMap < env var < explicit kwarg. +""" + +from __future__ import annotations + +import logging +import os +import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from shekel._budget import Budget + +logger = logging.getLogger(__name__) + +_SA_NAMESPACE_FILE = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + + +def is_k8s_environment() -> bool: + return bool(os.environ.get("KUBERNETES_SERVICE_HOST") and os.environ.get("SHEKEL_BUDGET_NAME")) + + +def _read_namespace() -> str: + ns = os.environ.get("SHEKEL_BUDGET_NAMESPACE") + if ns: + return ns + try: + with open(_SA_NAMESPACE_FILE) as f: + return f.read().strip() + except OSError: + return "default" + + +def _fetch_configmap(budget_name: str, namespace: str) -> dict[str, str] | None: + try: + import kubernetes + except ImportError: + logger.warning( + "shekel[k8s]: 'kubernetes' package not installed — skipping K8s config discovery." + " Install with: pip install shekel[k8s]" + ) + return None + + try: + kubernetes.config.load_incluster_config() + v1 = kubernetes.client.CoreV1Api() + cm = v1.read_namespaced_config_map( + name=f"shekel-budget-{budget_name}", + namespace=namespace, + ) + return dict(cm.data or {}) + except Exception as exc: + logger.warning("shekel[k8s]: Failed to load ConfigMap for %r: %s", budget_name, exc) + return None + + +def apply_k8s_config(budget: Budget) -> None: + """Load K8s ConfigMap and apply values to *budget* (mutates in place). + + Called at the end of Budget.__init__. Only fills fields that are still + None — explicit kwargs and env vars take precedence. + """ + if not is_k8s_environment(): + return + + budget_name = os.environ["SHEKEL_BUDGET_NAME"] + namespace = _read_namespace() + cm = _fetch_configmap(budget_name, namespace) + + if cm is None: + return + + # --- Kill-switch (immediate, before poll thread starts) --- + if cm.get("paused") == "true": + budget._paused_externally = True + + # --- max_usd: env var > ConfigMap --- + if budget.max_usd is None: + env_val = os.environ.get("AGENT_BUDGET_USD") + if env_val: + budget.max_usd = float(env_val) + elif "max_usd" in cm: + budget.max_usd = float(cm["max_usd"]) + + # --- warn_at --- + if budget.warn_at is None and "warn_at" in cm: + budget.warn_at = float(cm["warn_at"]) + + # --- max_llm_calls --- + if budget.max_llm_calls is None and "max_llm_calls" in cm: + budget.max_llm_calls = int(cm["max_llm_calls"]) + + # --- fallback --- + if budget.fallback is None and "fallback_model" in cm and "fallback_at_pct" in cm: + budget.fallback = { + "model": cm["fallback_model"], + "at_pct": float(cm["fallback_at_pct"]), + } + + # --- per_pod_cap --- + if "per_pod_cap" in cm: + from shekel._budget import Budget as _Budget # noqa: PLC0415 + + budget._per_pod_budget = _Budget(max_usd=float(cm["per_pod_cap"])) + + # --- Redis backend --- + if cm.get("backend") == "redis": + redis_url = os.environ.get("REDIS_URL") + if redis_url: + try: + from shekel.backends.redis import RedisBackend # noqa: PLC0415 + + budget._k8s_redis_backend = RedisBackend(url=redis_url) + budget._k8s_redis_name = cm.get("redis_key", f"shekel:{namespace}:{budget_name}") + except ImportError: + logger.warning( + "shekel[k8s]: 'redis' package not installed — skipping Redis backend." + ) + + # --- SHEK-17 fields (stored for spend reporter) --- + budget._k8s_flush_every_usd = float(cm["flush_every_usd"]) if "flush_every_usd" in cm else None + budget._k8s_flush_every_seconds = ( + float(cm["flush_every_seconds"]) if "flush_every_seconds" in cm else None + ) + budget._k8s_scope_mode = cm.get("scope_mode") + budget._k8s_scope_group_by = cm.get("scope_group_by") + + # --- Start background kill-switch poller --- + interval = float(os.environ.get("SHEKEL_POLL_INTERVAL_SECONDS", "10")) + poller = KubernetesPoller(budget, budget_name, namespace, interval) + poller.start() + budget._k8s_poller = poller + + +class KubernetesPoller(threading.Thread): + """Daemon thread that polls the ConfigMap's *paused* key. + + Sets budget._paused_externally so the next LLM call raises BudgetExceededError + within one poll interval of the operator setting paused=true. + """ + + def __init__( + self, + budget: Budget, + budget_name: str, + namespace: str, + interval: float, + ) -> None: + super().__init__(daemon=True, name=f"shekel-k8s-poller-{budget_name}") + self._budget = budget + self._budget_name = budget_name + self._namespace = namespace + self._interval = interval + self._stop_event = threading.Event() + + def stop(self) -> None: + self._stop_event.set() + + def run(self) -> None: + while not self._stop_event.wait(self._interval): + cm = _fetch_configmap(self._budget_name, self._namespace) + if cm is not None: + self._budget._paused_externally = cm.get("paused") == "true" diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py new file mode 100644 index 0000000..bda1b04 --- /dev/null +++ b/tests/test_kubernetes_integration.py @@ -0,0 +1,683 @@ +"""Tests for SHEK-16: in-cluster Kubernetes auto-discovery from ConfigMap. + +All tests mock kubernetes.client.CoreV1Api — no live cluster required. +""" + +from __future__ import annotations + +import sys +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_configmap(data: dict[str, str]) -> MagicMock: + cm = MagicMock() + cm.data = data + return cm + + +def _make_k8s_mock(configmap_data: dict[str, str]) -> MagicMock: + """Return a mock kubernetes module whose CoreV1Api returns the given ConfigMap.""" + k8s = MagicMock() + api_instance = MagicMock() + api_instance.read_namespaced_config_map.return_value = _make_configmap(configmap_data) + k8s.client.CoreV1Api.return_value = api_instance + k8s.config.load_incluster_config = MagicMock() + return k8s + + +def _budget_with_k8s( + configmap_data: dict[str, str], + extra_env: dict[str, str] | None = None, + budget_kwargs: dict[str, Any] | None = None, + namespace: str = "default", +) -> Any: + """Create a Budget with K8s env vars set and a mocked CoreV1Api.""" + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock(configmap_data) + env = { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "test-budget", + **(extra_env or {}), + } + + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value=namespace)) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + b = Budget(**(budget_kwargs or {})) + return b + + +# --------------------------------------------------------------------------- +# Detection +# --------------------------------------------------------------------------- + + +class TestK8sDetection: + def test_no_env_vars_skips_discovery(self) -> None: + """Without K8s env vars, Budget is created normally.""" + from shekel._budget import Budget + + with patch.dict("os.environ", {}, clear=False): + # Remove K8s vars if present + env = { + k: v + for k, v in __import__("os").environ.items() + if k not in ("KUBERNETES_SERVICE_HOST", "SHEKEL_BUDGET_NAME") + } + with patch.dict("os.environ", env, clear=True): + b = Budget(max_usd=1.00) + assert b._paused_externally is False + assert b._k8s_poller is None + + def test_only_service_host_skips_discovery(self) -> None: + """KUBERNETES_SERVICE_HOST alone (no SHEKEL_BUDGET_NAME) skips K8s path.""" + import os + + from shekel._budget import Budget + + env = {k: v for k, v in os.environ.items() if k != "SHEKEL_BUDGET_NAME"} + env["KUBERNETES_SERVICE_HOST"] = "10.0.0.1" + with patch.dict("os.environ", env, clear=True): + b = Budget(max_usd=1.00) + assert b._k8s_poller is None + + def test_only_budget_name_skips_discovery(self) -> None: + """SHEKEL_BUDGET_NAME alone (no KUBERNETES_SERVICE_HOST) skips K8s path.""" + import os + + from shekel._budget import Budget + + env = {k: v for k, v in os.environ.items() if k != "KUBERNETES_SERVICE_HOST"} + env["SHEKEL_BUDGET_NAME"] = "test-budget" + with patch.dict("os.environ", env, clear=True): + b = Budget(max_usd=1.00) + assert b._k8s_poller is None + + +# --------------------------------------------------------------------------- +# ConfigMap loading +# --------------------------------------------------------------------------- + + +class TestConfigMapLoading: + def test_max_usd_loaded_from_configmap(self) -> None: + b = _budget_with_k8s({"max_usd": "0.50"}) + assert b.max_usd == pytest.approx(0.50) + + def test_warn_at_loaded_from_configmap(self) -> None: + b = _budget_with_k8s({"max_usd": "1.00", "warn_at": "0.8"}) + assert b.warn_at == pytest.approx(0.8) + + def test_max_llm_calls_loaded_from_configmap(self) -> None: + b = _budget_with_k8s({"max_llm_calls": "10"}) + assert b.max_llm_calls == 10 + + def test_fallback_loaded_from_configmap(self) -> None: + b = _budget_with_k8s( + { + "max_usd": "1.00", + "fallback_model": "gpt-4o-mini", + "fallback_at_pct": "0.8", + } + ) + assert b.fallback == {"model": "gpt-4o-mini", "at_pct": pytest.approx(0.8)} + + def test_empty_configmap_leaves_budget_unchanged(self) -> None: + b = _budget_with_k8s({}) + assert b.max_usd is None + assert b.warn_at is None + + def test_configmap_name_uses_budget_name_env_var(self) -> None: + """ConfigMap fetched as shekel-budget-{SHEKEL_BUDGET_NAME}.""" + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({"max_usd": "1.00"}) + env = {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "my-agent"} + + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + Budget() + + api = k8s_mock.client.CoreV1Api.return_value + api.read_namespaced_config_map.assert_called_once_with( + name="shekel-budget-my-agent", namespace="default" + ) + + +# --------------------------------------------------------------------------- +# Config priority +# --------------------------------------------------------------------------- + + +class TestConfigPriority: + def test_explicit_kwarg_overrides_configmap(self) -> None: + """Explicit max_usd=1.00 beats ConfigMap max_usd: 0.50.""" + b = _budget_with_k8s({"max_usd": "0.50"}, budget_kwargs={"max_usd": 1.00}) + assert b.max_usd == pytest.approx(1.00) + + def test_env_var_overrides_configmap(self) -> None: + """AGENT_BUDGET_USD=2.00 beats ConfigMap max_usd: 0.50.""" + b = _budget_with_k8s( + {"max_usd": "0.50"}, + extra_env={"AGENT_BUDGET_USD": "2.00"}, + ) + assert b.max_usd == pytest.approx(2.00) + + def test_explicit_kwarg_overrides_env_var(self) -> None: + """Explicit kwarg beats env var.""" + b = _budget_with_k8s( + {}, + extra_env={"AGENT_BUDGET_USD": "2.00"}, + budget_kwargs={"max_usd": 5.00}, + ) + assert b.max_usd == pytest.approx(5.00) + + def test_configmap_applied_when_no_kwarg_or_env(self) -> None: + """ConfigMap is used when neither kwarg nor env var is set.""" + import os + + env = {k: v for k, v in os.environ.items() if k != "AGENT_BUDGET_USD"} + b = _budget_with_k8s({"max_usd": "0.75"}, extra_env=env) + assert b.max_usd == pytest.approx(0.75) + + +# --------------------------------------------------------------------------- +# Namespace resolution +# --------------------------------------------------------------------------- + + +class TestNamespaceResolution: + def test_namespace_read_from_sa_file(self) -> None: + """Namespace is read from the ServiceAccount namespace file.""" + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({}) + env = {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "b"} + + open_mock = MagicMock() + open_mock.return_value.__enter__ = MagicMock( + return_value=MagicMock(read=MagicMock(return_value="my-namespace")) + ) + open_mock.return_value.__exit__ = MagicMock(return_value=False) + + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch("builtins.open", open_mock): + Budget() + + api = k8s_mock.client.CoreV1Api.return_value + api.read_namespaced_config_map.assert_called_once_with( + name="shekel-budget-b", namespace="my-namespace" + ) + + def test_shekel_budget_namespace_env_overrides_sa_file(self) -> None: + """SHEKEL_BUDGET_NAMESPACE overrides the SA namespace file.""" + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({}) + env = { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "b", + "SHEKEL_BUDGET_NAMESPACE": "override-ns", + } + + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + Budget() + + api = k8s_mock.client.CoreV1Api.return_value + api.read_namespaced_config_map.assert_called_once_with( + name="shekel-budget-b", namespace="override-ns" + ) + + +# --------------------------------------------------------------------------- +# Kill-switch (paused flag) +# --------------------------------------------------------------------------- + + +class TestKillSwitch: + def test_paused_true_sets_flag(self) -> None: + b = _budget_with_k8s({"paused": "true"}) + assert b._paused_externally is True + + def test_paused_false_does_not_set_flag(self) -> None: + b = _budget_with_k8s({"paused": "false"}) + assert b._paused_externally is False + + def test_paused_missing_does_not_set_flag(self) -> None: + b = _budget_with_k8s({}) + assert b._paused_externally is False + + def test_paused_budget_raises_on_record_spend(self) -> None: + from shekel._budget import Budget + from shekel.exceptions import BudgetExceededError + + b = Budget(max_usd=1.00) + b._paused_externally = True + + with pytest.raises(BudgetExceededError): + b._record_spend(0.01, "gpt-4o-mini", {"input": 10, "output": 5}) + + # Paused check fires before spend is accumulated + assert b._spent == pytest.approx(0.0) + + def test_not_paused_does_not_raise_on_record_spend(self) -> None: + from shekel._budget import Budget + + b = Budget(max_usd=1.00) + b._paused_externally = False + # should not raise + b._record_spend(0.01, "gpt-4o-mini", {"input": 10, "output": 5}) + + +# --------------------------------------------------------------------------- +# Background poll thread +# --------------------------------------------------------------------------- + + +class TestPollThread: + def test_poll_thread_started_in_k8s_mode(self) -> None: + b = _budget_with_k8s({"max_usd": "1.00"}) + assert b._k8s_poller is not None + assert b._k8s_poller.is_alive() + b._k8s_poller.stop() + + def test_poll_thread_is_daemon(self) -> None: + b = _budget_with_k8s({"max_usd": "1.00"}) + assert b._k8s_poller.daemon is True + b._k8s_poller.stop() + + def test_poll_thread_stopped_on_budget_exit(self) -> None: + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({"max_usd": "1.00"}) + env = {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "b"} + + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + with Budget() as b: + poller = b._k8s_poller + + assert poller._stop_event.is_set() + + def test_poll_updates_paused_flag(self) -> None: + """Poller sets _paused_externally when ConfigMap changes to paused=true.""" + from shekel._budget import Budget + from shekel.integrations.kubernetes import KubernetesPoller + + b = Budget(max_usd=1.00) + b._paused_externally = False + + k8s_mock = _make_k8s_mock({"paused": "true"}) + + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + poller = KubernetesPoller(b, "test", "default", interval=0.01) + poller.start() + poller._stop_event.wait(timeout=0.5) + poller.stop() + poller.join(timeout=1.0) + + assert b._paused_externally is True + + def test_poll_clears_paused_flag_when_unpaused(self) -> None: + """Poller clears _paused_externally when ConfigMap changes to paused=false.""" + from shekel._budget import Budget + from shekel.integrations.kubernetes import KubernetesPoller + + b = Budget(max_usd=1.00) + b._paused_externally = True + + k8s_mock = _make_k8s_mock({"paused": "false"}) + + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + poller = KubernetesPoller(b, "test", "default", interval=0.01) + poller.start() + poller._stop_event.wait(timeout=0.5) + poller.stop() + poller.join(timeout=1.0) + + assert b._paused_externally is False + + +# --------------------------------------------------------------------------- +# kubernetes package absent +# --------------------------------------------------------------------------- + + +class TestKubernetesPackageAbsent: + def test_missing_kubernetes_package_no_crash(self) -> None: + """If kubernetes is not installed, Budget construction doesn't crash.""" + from shekel._budget import Budget + + env = {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "b"} + + # Remove kubernetes from sys.modules entirely + modules_without_k8s = {k: v for k, v in sys.modules.items() if "kubernetes" not in k} + modules_without_k8s["kubernetes"] = None # type: ignore[assignment] + + with patch.dict("os.environ", env, clear=False): + with patch.dict("sys.modules", modules_without_k8s, clear=True): + b = Budget(max_usd=1.00) + + assert b.max_usd == pytest.approx(1.00) + assert b._k8s_poller is None + + def test_missing_kubernetes_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + import logging + + from shekel._budget import Budget + + env = {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "b"} + modules_without_k8s = {k: v for k, v in sys.modules.items() if "kubernetes" not in k} + modules_without_k8s["kubernetes"] = None # type: ignore[assignment] + + with caplog.at_level(logging.WARNING, logger="shekel.integrations.kubernetes"): + with patch.dict("os.environ", env, clear=False): + with patch.dict("sys.modules", modules_without_k8s, clear=True): + Budget(max_usd=1.00) + + assert any("kubernetes" in r.message.lower() for r in caplog.records) + + +# --------------------------------------------------------------------------- +# Redis backend activation +# --------------------------------------------------------------------------- + + +class TestRedisBackendActivation: + def test_redis_backend_activated_from_configmap(self) -> None: + """ConfigMap backend=redis + REDIS_URL → RedisBackend stored on budget.""" + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock( + { + "backend": "redis", + "redis_key": "shekel:default:test-budget", + } + ) + env = { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "test-budget", + "REDIS_URL": "redis://localhost:6379/0", + } + + redis_backend_mock = MagicMock() + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + with patch( + "shekel.backends.redis.RedisBackend", return_value=redis_backend_mock + ) as mock_cls: + b = Budget() + + mock_cls.assert_called_once_with(url="redis://localhost:6379/0") + assert b._k8s_redis_backend is redis_backend_mock + assert b._k8s_redis_name == "shekel:default:test-budget" + + def test_redis_backend_skipped_without_redis_url(self) -> None: + """ConfigMap backend=redis but no REDIS_URL → no RedisBackend.""" + import os + + env_no_redis = {k: v for k, v in os.environ.items() if k != "REDIS_URL"} + env_no_redis.update( + { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "test-budget", + } + ) + + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({"backend": "redis"}) + + with patch.dict("os.environ", env_no_redis, clear=True): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + b = Budget() + + assert not hasattr(b, "_k8s_redis_backend") or b._k8s_redis_backend is None + + +# --------------------------------------------------------------------------- +# Per-pod cap +# --------------------------------------------------------------------------- + + +class TestPerPodCap: + def test_per_pod_cap_stored_on_budget(self) -> None: + b = _budget_with_k8s({"per_pod_cap": "0.25"}) + assert hasattr(b, "_per_pod_budget") + assert b._per_pod_budget.max_usd == pytest.approx(0.25) + + +# --------------------------------------------------------------------------- +# SHEK-17 fields stored +# --------------------------------------------------------------------------- + + +class TestErrorPaths: + def test_configmap_fetch_error_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + """CoreV1Api raising an exception logs a warning and returns None (no crash).""" + import logging + + from shekel._budget import Budget + + k8s_mock = MagicMock() + k8s_mock.config.load_incluster_config = MagicMock() + k8s_mock.client.CoreV1Api.return_value.read_namespaced_config_map.side_effect = ( + RuntimeError("timeout") + ) + env = {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "b"} + + with caplog.at_level(logging.WARNING, logger="shekel.integrations.kubernetes"): + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + b = Budget(max_usd=1.00) + + assert b.max_usd == pytest.approx(1.00) + assert any("Failed to load ConfigMap" in r.message for r in caplog.records) + + def test_redis_import_error_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + """Missing redis package during backend activation logs a warning (no crash).""" + import logging + + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({"backend": "redis"}) + env = { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "b", + "REDIS_URL": "redis://localhost:6379/0", + } + modules_no_redis = { + k: v for k, v in sys.modules.items() if not k.startswith("shekel.backends.redis") + } + modules_no_redis["shekel.backends.redis"] = None # type: ignore[assignment] + + with caplog.at_level(logging.WARNING, logger="shekel.integrations.kubernetes"): + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + **{ + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + **modules_no_redis, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + Budget(max_usd=1.00) + + assert any("redis" in r.message.lower() for r in caplog.records) + + +class TestShek17Fields: + def test_flush_every_usd_stored(self) -> None: + b = _budget_with_k8s({"flush_every_usd": "0.10"}) + assert b._k8s_flush_every_usd == pytest.approx(0.10) + + def test_flush_every_seconds_stored(self) -> None: + b = _budget_with_k8s({"flush_every_seconds": "30"}) + assert b._k8s_flush_every_seconds == pytest.approx(30.0) + + def test_scope_mode_stored(self) -> None: + b = _budget_with_k8s({"scope_mode": "shared"}) + assert b._k8s_scope_mode == "shared" + + def test_scope_group_by_stored(self) -> None: + b = _budget_with_k8s({"scope_group_by": "team"}) + assert b._k8s_scope_group_by == "team" From ed73470a1febba8827ad2b0962dcf7a00957099c Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Thu, 4 Jun 2026 12:29:19 +0000 Subject: [PATCH 02/12] =?UTF-8?q?feat:=20SHEK-17=20=E2=80=94=20periodic=20?= =?UTF-8?q?spend=20reporting=20to=20Kubernetes=20ConfigMap?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New KubernetesSpendReporter daemon thread (shekel/integrations/kubernetes.py): - Active when ConfigMap has backend=k8s; skipped for backend=redis or absent - Accumulates cumulative LLM spend/calls under threading.Lock (never hold lock across network call) - Flush triggers: flush_every_seconds (time-based), flush_every_usd (USD threshold on delta since last flush), Budget.__exit__/__aexit__ (always, including on exception) - Patch-or-create ConfigMap shekel-spend-{HOSTNAME}: patch first, create on 404 ApiException; any failure logs WARNING and never raises to caller - After successful write updates _last_flush_spent so next flush computes correct delta; baseline unchanged on failure so full cumulative total retried - HOSTNAME absent → flush silently skipped - Correct labels: shekel.dev/spend-report, shekel.dev/budget, shekel.dev/group (omitted when SHEKEL_GROUP_VALUE is empty) Budget._record_spend: calls reporter.on_spend(cost) after each LLM call. Budget.__exit__ / __aexit__: calls reporter.flush_and_stop() on context exit. 41 new tests; 100% coverage on kubernetes.py. Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 9 +- shekel/integrations/kubernetes.py | 131 ++++++- tests/test_kubernetes_integration.py | 520 ++++++++++++++++++++++++++- 3 files changed, 657 insertions(+), 3 deletions(-) diff --git a/shekel/_budget.py b/shekel/_budget.py index 17dd7ee..48a64b4 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -300,9 +300,10 @@ def __init__( self._chain_budgets: dict[str, ComponentBudget] = {} self._runtime: Any = None - # --- Kubernetes auto-discovery (SHEK-16) --- + # --- Kubernetes auto-discovery (SHEK-16) / spend reporting (SHEK-17) --- self._paused_externally: bool = False self._k8s_poller: Any = None + self._k8s_reporter: Any = None self._per_pod_budget: Any = None self._k8s_redis_backend: Any = None self._k8s_redis_name: str | None = None @@ -487,6 +488,8 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: if self._runtime is not None: self._runtime.release() self._runtime = None + if self._k8s_reporter is not None: + self._k8s_reporter.flush_and_stop() if self._k8s_poller is not None: self._k8s_poller.stop() _patch.remove_patches() @@ -629,6 +632,8 @@ async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> if self._runtime is not None: self._runtime.release() self._runtime = None + if self._k8s_reporter is not None: + self._k8s_reporter.flush_and_stop() if self._k8s_poller is not None: self._k8s_poller.stop() _patch.remove_patches() @@ -695,6 +700,8 @@ def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None self._check_warn() self._check_limit() self._check_call_limit() + if self._k8s_reporter is not None: + self._k8s_reporter.on_spend(cost) def _check_warn(self) -> None: effective_limit = self._effective_limit diff --git a/shekel/integrations/kubernetes.py b/shekel/integrations/kubernetes.py index ceb0cc5..8706926 100644 --- a/shekel/integrations/kubernetes.py +++ b/shekel/integrations/kubernetes.py @@ -1,10 +1,14 @@ -"""Kubernetes in-cluster auto-discovery for shekel budget configuration (SHEK-16). +"""Kubernetes in-cluster auto-discovery and spend reporting for shekel (SHEK-16/17). When both KUBERNETES_SERVICE_HOST and SHEKEL_BUDGET_NAME are set, Budget.__init__ loads its configuration from the ConfigMap shekel-budget-{name} in the pod's namespace. A background daemon thread polls the ConfigMap's paused key at SHEKEL_POLL_INTERVAL_SECONDS (default 10) to implement the kill-switch. +When ConfigMap has backend=k8s, a KubernetesSpendReporter daemon thread +periodically writes cumulative LLM spend to a shekel-spend-{pod} ConfigMap so +the controller can aggregate across pods. + Config priority (lowest → highest): ConfigMap < env var < explicit kwarg. """ @@ -13,6 +17,7 @@ import logging import os import threading +from datetime import datetime, timezone from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -132,6 +137,19 @@ def apply_k8s_config(budget: Budget) -> None: budget._k8s_scope_mode = cm.get("scope_mode") budget._k8s_scope_group_by = cm.get("scope_group_by") + # --- SHEK-17: Spend reporter (only when backend=k8s) --- + if cm.get("backend") == "k8s": + group_value = os.environ.get("SHEKEL_GROUP_VALUE", "") + reporter = KubernetesSpendReporter( + budget_name=budget_name, + namespace=namespace, + flush_every_seconds=budget._k8s_flush_every_seconds or 60.0, + flush_every_usd=budget._k8s_flush_every_usd, + group_value=group_value, + ) + reporter.start() + budget._k8s_reporter = reporter + # --- Start background kill-switch poller --- interval = float(os.environ.get("SHEKEL_POLL_INTERVAL_SECONDS", "10")) poller = KubernetesPoller(budget, budget_name, namespace, interval) @@ -168,3 +186,114 @@ def run(self) -> None: cm = _fetch_configmap(self._budget_name, self._namespace) if cm is not None: self._budget._paused_externally = cm.get("paused") == "true" + + +class KubernetesSpendReporter(threading.Thread): + """Daemon thread that flushes cumulative LLM spend to a Spend Report ConfigMap. + + Active when ConfigMap has backend=k8s. The controller reads spent_usd from + each pod's ConfigMap and aggregates across pods — no Redis required. + + Flush triggers (whichever fires first): + - flush_every_seconds elapsed since last flush (time-based) + - flush_every_usd accumulated since last flush (USD threshold) + - Budget.__exit__ / __aexit__ (always on context exit, including on exception) + """ + + def __init__( + self, + budget_name: str, + namespace: str, + flush_every_seconds: float = 60.0, + flush_every_usd: float | None = None, + group_value: str = "", + ) -> None: + super().__init__(daemon=True, name=f"shekel-k8s-reporter-{budget_name}") + self._budget_name = budget_name + self._namespace = namespace + self._flush_every_seconds = flush_every_seconds + self._flush_every_usd = flush_every_usd + self._group_value = group_value + self._lock = threading.Lock() + self._total_spent: float = 0.0 + self._total_calls: int = 0 + self._last_flush_spent: float = 0.0 + self._stop_event = threading.Event() + + def on_spend(self, cost: float) -> None: + """Called from Budget._record_spend after each LLM call.""" + with self._lock: + self._total_spent += cost + self._total_calls += 1 + delta = self._total_spent - self._last_flush_spent + if self._flush_every_usd is not None and delta >= self._flush_every_usd: + self._flush() + + def stop(self) -> None: + self._stop_event.set() + + def flush_and_stop(self) -> None: + """Stop the background thread and perform a final synchronous flush.""" + self._stop_event.set() + self._flush() + + def run(self) -> None: + while not self._stop_event.wait(self._flush_every_seconds): + self._flush() + + def _flush(self) -> None: + pod_name = os.environ.get("HOSTNAME") + if not pod_name: + return + + try: + import kubernetes # noqa: PLC0415 + except ImportError: # pragma: no cover — optional dependency + return + + with self._lock: + total_spent = self._total_spent + total_calls = self._total_calls + + cm_name = f"shekel-spend-{pod_name}" + labels: dict[str, str] = { + "shekel.dev/spend-report": "true", + "shekel.dev/budget": self._budget_name, + } + if self._group_value: + labels["shekel.dev/group"] = self._group_value + + body = { + "apiVersion": "v1", + "kind": "ConfigMap", + "metadata": { + "name": cm_name, + "namespace": self._namespace, + "labels": labels, + }, + "data": { + "spent_usd": str(total_spent), + "call_count": str(total_calls), + "last_updated": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "pod_name": pod_name, + }, + } + + success = False + try: + kubernetes.config.load_incluster_config() + v1 = kubernetes.client.CoreV1Api() + try: + v1.patch_namespaced_config_map(cm_name, self._namespace, body) + except kubernetes.client.ApiException as exc: + if exc.status == 404: + v1.create_namespaced_config_map(self._namespace, body) + else: + raise + success = True + except Exception as exc: + logger.warning("shekel[k8s]: Failed to flush spend report: %s", exc) + + if success: + with self._lock: + self._last_flush_spent = total_spent diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index bda1b04..fb3eb79 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -1,11 +1,13 @@ -"""Tests for SHEK-16: in-cluster Kubernetes auto-discovery from ConfigMap. +"""Tests for SHEK-16 / SHEK-17: Kubernetes auto-discovery and spend reporting. All tests mock kubernetes.client.CoreV1Api — no live cluster required. """ from __future__ import annotations +import contextlib import sys +import threading from typing import Any from unittest.mock import MagicMock, patch @@ -681,3 +683,519 @@ def test_scope_mode_stored(self) -> None: def test_scope_group_by_stored(self) -> None: b = _budget_with_k8s({"scope_group_by": "team"}) assert b._k8s_scope_group_by == "team" + + +# --------------------------------------------------------------------------- +# SHEK-17: KubernetesSpendReporter +# --------------------------------------------------------------------------- + + +@contextlib.contextmanager +def _k8s_sys_modules(k8s_mock: MagicMock): # type: ignore[return] + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + yield + + +@contextlib.contextmanager +def _flush_env(k8s_mock: MagicMock, hostname: str = "test-pod"): # type: ignore[return] + with patch.dict("os.environ", {"HOSTNAME": hostname}, clear=False): + with _k8s_sys_modules(k8s_mock): + yield + + +def _make_api_exception_class(k8s_mock: MagicMock) -> type: + class FakeApiException(Exception): + def __init__(self, status: int = 500) -> None: + self.status = status + + k8s_mock.client.ApiException = FakeApiException + return FakeApiException + + +class TestKubernetesSpendReporter: + # ── Activation ──────────────────────────────────────────────────────── + + def test_reporter_not_started_when_backend_absent(self) -> None: + b = _budget_with_k8s({}) + assert b._k8s_reporter is None + + def test_reporter_not_started_when_backend_redis(self) -> None: + b = _budget_with_k8s({"backend": "redis"}, extra_env={"REDIS_URL": "redis://localhost"}) + assert b._k8s_reporter is None + + def test_reporter_started_when_backend_k8s(self) -> None: + b = _budget_with_k8s({"backend": "k8s"}) + assert b._k8s_reporter is not None + + def test_reporter_flush_every_seconds_from_configmap(self) -> None: + b = _budget_with_k8s({"backend": "k8s", "flush_every_seconds": "45"}) + assert b._k8s_reporter._flush_every_seconds == pytest.approx(45.0) + + def test_reporter_flush_every_usd_from_configmap(self) -> None: + b = _budget_with_k8s({"backend": "k8s", "flush_every_usd": "0.25"}) + assert b._k8s_reporter._flush_every_usd == pytest.approx(0.25) + + def test_reporter_flush_every_seconds_defaults_to_60(self) -> None: + b = _budget_with_k8s({"backend": "k8s"}) + assert b._k8s_reporter._flush_every_seconds == pytest.approx(60.0) + + def test_reporter_group_value_from_env(self) -> None: + b = _budget_with_k8s({"backend": "k8s"}, extra_env={"SHEKEL_GROUP_VALUE": "team-a"}) + assert b._k8s_reporter._group_value == "team-a" + + def test_reporter_group_value_empty_by_default(self) -> None: + import os + + env = {k: v for k, v in os.environ.items() if k != "SHEKEL_GROUP_VALUE"} + b = _budget_with_k8s({"backend": "k8s"}, extra_env=env) + assert b._k8s_reporter._group_value == "" + + # ── Spend accumulation ───────────────────────────────────────────────── + + def test_on_spend_accumulates_total_spent(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns") + r.on_spend(0.10) + r.on_spend(0.25) + assert r._total_spent == pytest.approx(0.35) + + def test_on_spend_accumulates_total_calls(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns") + r.on_spend(0.10) + r.on_spend(0.20) + assert r._total_calls == 2 + + def test_on_spend_no_flush_when_no_usd_threshold(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_usd=None) + with patch.object(r, "_flush") as mock_flush: + r.on_spend(999.99) + mock_flush.assert_not_called() + + # ── USD threshold ────────────────────────────────────────────────────── + + def test_usd_threshold_triggers_flush(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_usd=0.10) + with patch.object(r, "_flush") as mock_flush: + r.on_spend(0.11) + mock_flush.assert_called_once() + + def test_usd_threshold_not_triggered_below_threshold(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_usd=1.00) + with patch.object(r, "_flush") as mock_flush: + r.on_spend(0.50) + mock_flush.assert_not_called() + + def test_usd_threshold_exact_boundary_triggers_flush(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_usd=0.10) + with patch.object(r, "_flush") as mock_flush: + r.on_spend(0.10) + mock_flush.assert_called_once() + + def test_usd_threshold_delta_uses_last_flush_as_baseline(self) -> None: + """Delta is relative to _last_flush_spent, not zero.""" + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_usd=0.50) + r._last_flush_spent = 0.40 + r._total_spent = 0.40 + + with patch.object(r, "_flush") as mock_flush: + r.on_spend(0.45) # total=0.85, delta=0.45 — below 0.50 + mock_flush.assert_not_called() + r.on_spend(0.10) # total=0.95, delta=0.55 — above 0.50 + mock_flush.assert_called_once() + + # ── Background flush thread ──────────────────────────────────────────── + + def test_run_calls_flush_on_interval(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_seconds=60.0) + flush_calls = [0] + r._flush = lambda: flush_calls.__setitem__(0, flush_calls[0] + 1) # type: ignore[assignment] + + with patch.object(r._stop_event, "wait", side_effect=iter([False, True])): + r.run() + + assert flush_calls[0] == 1 + + def test_run_passes_interval_to_wait(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_seconds=30.0) + r._flush = MagicMock() # type: ignore[assignment] + + with patch.object(r._stop_event, "wait", side_effect=iter([True])) as mock_wait: + r.run() + + mock_wait.assert_called_with(30.0) + + def test_run_stops_when_stop_event_set(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns") + r._flush = MagicMock() # type: ignore[assignment] + + with patch.object(r._stop_event, "wait", side_effect=iter([True])): + r.run() + + r._flush.assert_not_called() # type: ignore[union-attr] + + # ── Exit flush ───────────────────────────────────────────────────────── + + def test_exit_flush_on_sync_context_exit(self) -> None: + b = _budget_with_k8s({"backend": "k8s"}) + mock_flush_stop = MagicMock() + b._k8s_reporter.flush_and_stop = mock_flush_stop + + with patch("shekel._patch.remove_patches"): + b.__exit__(None, None, None) + + mock_flush_stop.assert_called_once() + + def test_exit_flush_on_exception_exit(self) -> None: + b = _budget_with_k8s({"backend": "k8s"}) + mock_flush_stop = MagicMock() + b._k8s_reporter.flush_and_stop = mock_flush_stop + + with patch("shekel._patch.remove_patches"): + b.__exit__(RuntimeError, RuntimeError("boom"), None) + + mock_flush_stop.assert_called_once() + + def test_exit_flush_on_async_context_exit(self) -> None: + import asyncio + + b = _budget_with_k8s({"backend": "k8s"}) + mock_flush_stop = MagicMock() + b._k8s_reporter.flush_and_stop = mock_flush_stop + + async def run() -> None: + with patch("shekel._patch.remove_patches"): + await b.__aexit__(None, None, None) + + asyncio.run(run()) + mock_flush_stop.assert_called_once() + + def test_exit_flush_on_async_exception_exit(self) -> None: + import asyncio + + b = _budget_with_k8s({"backend": "k8s"}) + mock_flush_stop = MagicMock() + b._k8s_reporter.flush_and_stop = mock_flush_stop + + async def run() -> None: + with patch("shekel._patch.remove_patches"): + await b.__aexit__(RuntimeError, RuntimeError("async boom"), None) + + asyncio.run(run()) + mock_flush_stop.assert_called_once() + + # ── _flush() — hostname guard ────────────────────────────────────────── + + def test_flush_skipped_when_hostname_absent(self) -> None: + import os + + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.50 + + env = {k: v for k, v in os.environ.items() if k != "HOSTNAME"} + with patch.dict("os.environ", env, clear=True): + with _k8s_sys_modules(k8s_mock): + r._flush() + + api = k8s_mock.client.CoreV1Api.return_value + assert not api.patch_namespaced_config_map.called + assert not api.create_namespaced_config_map.called + + # ── _flush() — patch-or-create logic ────────────────────────────────── + + def test_flush_patches_configmap_first(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.42 + r._total_calls = 3 + + with _flush_env(k8s_mock, hostname="my-pod"): + r._flush() + + api = k8s_mock.client.CoreV1Api.return_value + api.patch_namespaced_config_map.assert_called_once() + api.create_namespaced_config_map.assert_not_called() + + def test_flush_creates_on_404(self) -> None: + k8s_mock = _make_k8s_mock({}) + FakeApiException = _make_api_exception_class(k8s_mock) + api = k8s_mock.client.CoreV1Api.return_value + api.patch_namespaced_config_map.side_effect = FakeApiException(404) + + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.10 + + with _flush_env(k8s_mock, hostname="my-pod"): + r._flush() + + api.create_namespaced_config_map.assert_called_once() + + def test_flush_non_404_api_exception_logged_as_warning( + self, caplog: pytest.LogCaptureFixture + ) -> None: + import logging + + k8s_mock = _make_k8s_mock({}) + FakeApiException = _make_api_exception_class(k8s_mock) + api = k8s_mock.client.CoreV1Api.return_value + api.patch_namespaced_config_map.side_effect = FakeApiException(503) + + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.10 + + with caplog.at_level(logging.WARNING, logger="shekel.integrations.kubernetes"): + with _flush_env(k8s_mock, hostname="my-pod"): + r._flush() # must not raise + + assert any("Failed to flush" in rec.message for rec in caplog.records) + + def test_flush_logs_warning_on_generic_failure(self, caplog: pytest.LogCaptureFixture) -> None: + import logging + + k8s_mock = _make_k8s_mock({}) + api = k8s_mock.client.CoreV1Api.return_value + api.patch_namespaced_config_map.side_effect = OSError("connection refused") + + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.10 + + with caplog.at_level(logging.WARNING, logger="shekel.integrations.kubernetes"): + with _flush_env(k8s_mock, hostname="my-pod"): + r._flush() # must not raise + + assert any("Failed to flush" in rec.message for rec in caplog.records) + + def test_flush_does_not_raise_on_failure(self) -> None: + k8s_mock = _make_k8s_mock({}) + api = k8s_mock.client.CoreV1Api.return_value + api.patch_namespaced_config_map.side_effect = OSError("network error") + + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.10 + + with _flush_env(k8s_mock, hostname="my-pod"): + r._flush() # must not raise + + # ── _flush() — ConfigMap body ────────────────────────────────────────── + + def test_flush_configmap_name_and_namespace(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("my-budget", "production") + r._total_spent = 0.10 + + with _flush_env(k8s_mock, hostname="worker-pod-7"): + r._flush() + + api = k8s_mock.client.CoreV1Api.return_value + args = api.patch_namespaced_config_map.call_args[0] + assert args[0] == "shekel-spend-worker-pod-7" + assert args[1] == "production" + + def test_flush_configmap_labels_correct(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("my-budget", "default", group_value="team-a") + r._total_spent = 0.10 + + with _flush_env(k8s_mock, hostname="pod-1"): + r._flush() + + body = k8s_mock.client.CoreV1Api.return_value.patch_namespaced_config_map.call_args[0][2] + labels = body["metadata"]["labels"] + assert labels["shekel.dev/spend-report"] == "true" + assert labels["shekel.dev/budget"] == "my-budget" + assert labels["shekel.dev/group"] == "team-a" + + def test_flush_group_label_omitted_when_empty(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default", group_value="") + r._total_spent = 0.10 + + with _flush_env(k8s_mock, hostname="pod-1"): + r._flush() + + body = k8s_mock.client.CoreV1Api.return_value.patch_namespaced_config_map.call_args[0][2] + assert "shekel.dev/group" not in body["metadata"]["labels"] + + def test_flush_configmap_data_fields(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 1.23 + r._total_calls = 7 + + with _flush_env(k8s_mock, hostname="test-pod"): + r._flush() + + data = k8s_mock.client.CoreV1Api.return_value.patch_namespaced_config_map.call_args[0][2][ + "data" + ] + assert data["spent_usd"] == "1.23" + assert data["call_count"] == "7" + assert data["pod_name"] == "test-pod" + assert "T" in data["last_updated"] and data["last_updated"].endswith("Z") + + def test_flush_writes_cumulative_total_not_delta(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._last_flush_spent = 0.50 + r._total_spent = 0.75 # delta = 0.25, but we expect cumulative 0.75 + r._total_calls = 5 + + with _flush_env(k8s_mock, hostname="pod-1"): + r._flush() + + body = k8s_mock.client.CoreV1Api.return_value.patch_namespaced_config_map.call_args[0][2] + assert body["data"]["spent_usd"] == "0.75" + + # ── _flush() — baseline tracking ────────────────────────────────────── + + def test_flush_updates_last_flush_spent_on_success(self) -> None: + k8s_mock = _make_k8s_mock({}) + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.75 + r._total_calls = 5 + + with _flush_env(k8s_mock, hostname="pod-1"): + r._flush() + + assert r._last_flush_spent == pytest.approx(0.75) + + def test_flush_does_not_update_baseline_on_failure(self) -> None: + """After a failed flush, _last_flush_spent is unchanged for retry with full total.""" + k8s_mock = _make_k8s_mock({}) + api = k8s_mock.client.CoreV1Api.return_value + api.patch_namespaced_config_map.side_effect = OSError("network error") + + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "default") + r._total_spent = 0.75 + r._last_flush_spent = 0.50 + + with _flush_env(k8s_mock, hostname="pod-1"): + r._flush() + + assert r._last_flush_spent == pytest.approx(0.50) + + # ── stop / flush_and_stop ────────────────────────────────────────────── + + def test_stop_sets_stop_event(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns") + assert not r._stop_event.is_set() + r.stop() + assert r._stop_event.is_set() + + def test_flush_and_stop_calls_flush(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns") + with patch.object(r, "_flush") as mock_flush: + r.flush_and_stop() + mock_flush.assert_called_once() + + def test_flush_and_stop_sets_stop_event(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns") + with patch.object(r, "_flush"): + r.flush_and_stop() + assert r._stop_event.is_set() + + # ── Budget._record_spend integration ────────────────────────────────── + + def test_record_spend_notifies_reporter(self) -> None: + b = _budget_with_k8s({"backend": "k8s"}, budget_kwargs={"max_usd": 10.0}) + assert b._k8s_reporter is not None + + with patch.object(b._k8s_reporter, "on_spend") as mock_on_spend: + b._record_spend(0.05, "gpt-4", {"input": 100, "output": 50}) + + mock_on_spend.assert_called_once_with(0.05) + + def test_record_spend_does_not_notify_when_no_reporter(self) -> None: + """Budget without K8s env doesn't have a reporter — no crash.""" + import os + + from shekel._budget import Budget + + env = { + k: v + for k, v in os.environ.items() + if k not in ("KUBERNETES_SERVICE_HOST", "SHEKEL_BUDGET_NAME") + } + with patch.dict("os.environ", env, clear=True): + b = Budget(max_usd=10.0) + + assert b._k8s_reporter is None + b._record_spend(0.05, "gpt-4", {"input": 100, "output": 50}) # must not raise + + # ── Thread safety ────────────────────────────────────────────────────── + + def test_concurrent_on_spend_calls_thread_safe(self) -> None: + from shekel.integrations.kubernetes import KubernetesSpendReporter + + r = KubernetesSpendReporter("b", "ns", flush_every_usd=None) + + def spend_batch() -> None: + for _ in range(100): + r.on_spend(0.01) + + threads = [threading.Thread(target=spend_batch) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert r._total_spent == pytest.approx(10.0) + assert r._total_calls == 1000 From f2cb46e90e1eba070ab7378c4a20d1cf8d38dfc3 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 08:16:33 +0000 Subject: [PATCH 03/12] =?UTF-8?q?fix:=20SHEK-26=20=E2=80=94=20replace=20pe?= =?UTF-8?q?r-pod=20cap=20nested=20Budget=20with=20float=20to=20eliminate?= =?UTF-8?q?=20infinite=20recursion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Constructing a child Budget for per_pod_cap caused unbounded recursion because KUBERNETES_SERVICE_HOST/SHEKEL_BUDGET_NAME are still set in the process, triggering apply_k8s_config() on every child __init__. Store the cap as _per_pod_cap_usd: float | None instead, and enforce it via a new _check_per_pod_limit() method called inside _record_spend(). Adds 4 regression tests including an explicit no-recursion guard. Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 14 +++++++++++++- shekel/integrations/kubernetes.py | 4 +--- tests/test_kubernetes_integration.py | 29 +++++++++++++++++++++++++--- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/shekel/_budget.py b/shekel/_budget.py index 48a64b4..0a04494 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -304,7 +304,7 @@ def __init__( self._paused_externally: bool = False self._k8s_poller: Any = None self._k8s_reporter: Any = None - self._per_pod_budget: Any = None + self._per_pod_cap_usd: float | None = None self._k8s_redis_backend: Any = None self._k8s_redis_name: str | None = None self._k8s_flush_every_usd: float | None = None @@ -700,6 +700,7 @@ def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None self._check_warn() self._check_limit() self._check_call_limit() + self._check_per_pod_limit() if self._k8s_reporter is not None: self._k8s_reporter.on_spend(cost) @@ -810,6 +811,17 @@ def _check_call_limit(self) -> None: self._last_tokens, ) + def _check_per_pod_limit(self) -> None: + """Enforce the per-pod USD cap set via ConfigMap per_pod_cap (SHEK-26).""" + if self._per_pod_cap_usd is None: + return + if self._spent > self._per_pod_cap_usd: + from shekel.exceptions import BudgetExceededError # noqa: PLC0415 + + raise BudgetExceededError( + self._spent, self._per_pod_cap_usd, self._last_model, self._last_tokens + ) + # ------------------------------------------------------------------ # Loop guard enforcement (v1.1.0) # ------------------------------------------------------------------ diff --git a/shekel/integrations/kubernetes.py b/shekel/integrations/kubernetes.py index 8706926..9556814 100644 --- a/shekel/integrations/kubernetes.py +++ b/shekel/integrations/kubernetes.py @@ -111,9 +111,7 @@ def apply_k8s_config(budget: Budget) -> None: # --- per_pod_cap --- if "per_pod_cap" in cm: - from shekel._budget import Budget as _Budget # noqa: PLC0415 - - budget._per_pod_budget = _Budget(max_usd=float(cm["per_pod_cap"])) + budget._per_pod_cap_usd = float(cm["per_pod_cap"]) # --- Redis backend --- if cm.get("backend") == "redis": diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index fb3eb79..8b07240 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -570,10 +570,33 @@ def test_redis_backend_skipped_without_redis_url(self) -> None: class TestPerPodCap: - def test_per_pod_cap_stored_on_budget(self) -> None: + def test_per_pod_cap_stored_as_float(self) -> None: b = _budget_with_k8s({"per_pod_cap": "0.25"}) - assert hasattr(b, "_per_pod_budget") - assert b._per_pod_budget.max_usd == pytest.approx(0.25) + assert b._per_pod_cap_usd == pytest.approx(0.25) + + def test_per_pod_cap_does_not_recurse(self) -> None: + # Regression for SHEK-26: constructing a Budget with per_pod_cap in the + # ConfigMap must not trigger infinite recursion via nested Budget.__init__ calls. + b = _budget_with_k8s({"per_pod_cap": "0.10"}) + assert b._per_pod_cap_usd == pytest.approx(0.10) + assert not hasattr(b, "_per_pod_budget") + + def test_per_pod_cap_enforced_on_exceed(self) -> None: + from shekel.exceptions import BudgetExceededError + + b = _budget_with_k8s({"per_pod_cap": "0.05"}) + with b: + b._record_spend(0.03, "gpt-4o", {"input": 100, "output": 50}) # under cap — ok + with pytest.raises(BudgetExceededError) as exc_info: + b._record_spend(0.03, "gpt-4o", {"input": 100, "output": 50}) # exceeds cap + assert exc_info.value.limit == pytest.approx(0.05) + + def test_per_pod_cap_not_enforced_when_absent(self) -> None: + # No per_pod_cap in ConfigMap → spending freely past any cap value must not raise. + b = _budget_with_k8s({}) + with b: + b._record_spend(0.50, "gpt-4o", {"input": 100, "output": 50}) + b._record_spend(0.50, "gpt-4o", {"input": 100, "output": 50}) # --------------------------------------------------------------------------- From a1d0d5d5cd2dd148f14dce0a18705d118501a821 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 08:34:29 +0000 Subject: [PATCH 04/12] =?UTF-8?q?fix:=20SHEK-27=20=E2=80=94=20restart=20K8?= =?UTF-8?q?s=20poller/reporter=20threads=20on=20Budget=20re-entry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After __exit__ stops the K8s threads, re-entering the same Budget instance (session budget pattern) left K8s polling and spend reporting permanently dead. Python threads cannot be restarted, so a new instance must be created. Persist budget_name, namespace, and poll_interval on the budget during apply_k8s_config(), then call _restart_k8s_threads() from __enter__ and __aenter__ to rebuild any stopped threads idempotently. Adds 6 regression tests including sync and async re-entry, idempotency, and no-K8s no-op. Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 41 +++++++++++++++ shekel/integrations/kubernetes.py | 3 ++ tests/test_kubernetes_integration.py | 75 ++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+) diff --git a/shekel/_budget.py b/shekel/_budget.py index 0a04494..f6ac1db 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -311,6 +311,9 @@ def __init__( self._k8s_flush_every_seconds: float | None = None self._k8s_scope_mode: str | None = None self._k8s_scope_group_by: str | None = None + self._k8s_budget_name: str | None = None + self._k8s_namespace: str | None = None + self._k8s_poll_interval: float = 10.0 try: from shekel.integrations.kubernetes import apply_k8s_config # noqa: PLC0415 @@ -434,6 +437,7 @@ def __enter__(self) -> Budget: self._runtime = ShekelRuntime(self) self._runtime.probe() + self._restart_k8s_threads() return self def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: @@ -579,6 +583,7 @@ async def __aenter__(self) -> Budget: self._runtime = ShekelRuntime(self) self._runtime.probe() + self._restart_k8s_threads() return self async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: @@ -823,6 +828,42 @@ def _check_per_pod_limit(self) -> None: ) # ------------------------------------------------------------------ + def _restart_k8s_threads(self) -> None: + """Restart stopped K8s poller/reporter on Budget re-entry (SHEK-27). + + Called by __enter__ / __aenter__. No-op when not in K8s mode or when + threads are still alive (idempotent on nested/repeated enters). + """ + if self._k8s_budget_name is None: + return + + if self._k8s_poller is not None and not self._k8s_poller.is_alive(): + from shekel.integrations.kubernetes import KubernetesPoller # noqa: PLC0415 + + poller = KubernetesPoller( + self, + self._k8s_budget_name, + self._k8s_namespace or "", + self._k8s_poll_interval, + ) + poller.start() + self._k8s_poller = poller + + if self._k8s_reporter is not None and not self._k8s_reporter.is_alive(): + import os # noqa: PLC0415 + + from shekel.integrations.kubernetes import KubernetesSpendReporter # noqa: PLC0415 + + reporter = KubernetesSpendReporter( + budget_name=self._k8s_budget_name, + namespace=self._k8s_namespace or "", + flush_every_seconds=self._k8s_flush_every_seconds or 60.0, + flush_every_usd=self._k8s_flush_every_usd, + group_value=os.environ.get("SHEKEL_GROUP_VALUE", ""), + ) + reporter.start() + self._k8s_reporter = reporter + # Loop guard enforcement (v1.1.0) # ------------------------------------------------------------------ diff --git a/shekel/integrations/kubernetes.py b/shekel/integrations/kubernetes.py index 9556814..01be4a2 100644 --- a/shekel/integrations/kubernetes.py +++ b/shekel/integrations/kubernetes.py @@ -77,6 +77,8 @@ def apply_k8s_config(budget: Budget) -> None: budget_name = os.environ["SHEKEL_BUDGET_NAME"] namespace = _read_namespace() + budget._k8s_budget_name = budget_name + budget._k8s_namespace = namespace cm = _fetch_configmap(budget_name, namespace) if cm is None: @@ -150,6 +152,7 @@ def apply_k8s_config(budget: Budget) -> None: # --- Start background kill-switch poller --- interval = float(os.environ.get("SHEKEL_POLL_INTERVAL_SECONDS", "10")) + budget._k8s_poll_interval = interval poller = KubernetesPoller(budget, budget_name, namespace, interval) poller.start() budget._k8s_poller = poller diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 8b07240..065bcc3 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -599,6 +599,81 @@ def test_per_pod_cap_not_enforced_when_absent(self) -> None: b._record_spend(0.50, "gpt-4o", {"input": 100, "output": 50}) +# --------------------------------------------------------------------------- +# SHEK-27: poller/reporter restart on Budget re-entry +# --------------------------------------------------------------------------- + + +class TestPollerRestart: + def test_poller_restarts_on_re_entry(self) -> None: + # Regression for SHEK-27: re-entering a session budget must spawn a new + # live poller thread — the old one was stopped on __exit__. + b = _budget_with_k8s({"max_usd": "1.00"}) + with b: + first_poller = b._k8s_poller + first_poller.join(timeout=2.0) # wait for thread to actually die after stop() + assert not first_poller.is_alive() + with b: + assert b._k8s_poller is not first_poller + assert b._k8s_poller.is_alive() + + def test_reporter_restarts_on_re_entry(self) -> None: + b = _budget_with_k8s({"max_usd": "1.00", "backend": "k8s"}) + with b: + first_reporter = b._k8s_reporter + first_reporter.join(timeout=2.0) # wait for thread to actually die + assert not first_reporter.is_alive() + with b: + assert b._k8s_reporter is not first_reporter + assert b._k8s_reporter.is_alive() + + def test_restart_idempotent_when_thread_alive(self) -> None: + # Calling _restart_k8s_threads() while the thread is still alive must not + # create a duplicate — the same instance should be kept. + b = _budget_with_k8s({"max_usd": "1.00"}) + with b: + poller_id = id(b._k8s_poller) + b._restart_k8s_threads() + assert id(b._k8s_poller) == poller_id + + def test_restart_noop_outside_k8s(self) -> None: + from shekel._budget import Budget + + b = Budget(max_usd=1.0) + b._restart_k8s_threads() # must not raise + assert b._k8s_poller is None + assert b._k8s_reporter is None + + async def test_async_poller_restarts_on_re_entry(self) -> None: + # Same as test_poller_restarts_on_re_entry but via async with — covers + # __aenter__ / __aexit__ paths including _restart_k8s_threads() call. + b = _budget_with_k8s({"max_usd": "1.00"}) + async with b: + first_poller = b._k8s_poller + first_poller.join(timeout=2.0) + assert not first_poller.is_alive() + async with b: + assert b._k8s_poller is not first_poller + assert b._k8s_poller.is_alive() + + def test_exit_exceeded_status_with_k8s_poller(self) -> None: + # Exercises the "exceeded" exit-status branch in __exit__ while K8s threads are active. + from shekel.exceptions import BudgetExceededError + + b = _budget_with_k8s({"max_usd": "0.01"}) + with pytest.raises(BudgetExceededError): + with b: + b._record_spend(0.02, "gpt-4o", {"input": 10, "output": 5}) + + def test_exit_warned_status_with_k8s_poller(self) -> None: + # Exercises the "warned" exit-status branch in __exit__ while K8s threads are active. + warned = [] + b = _budget_with_k8s({"max_usd": "1.00", "warn_at": "0.5"}) + b.on_warn = lambda spent, limit: warned.append(spent) + with b: + b._record_spend(0.60, "gpt-4o", {"input": 10, "output": 5}) + + # --------------------------------------------------------------------------- # SHEK-17 fields stored # --------------------------------------------------------------------------- From 5751881e5b314e2be6eb4d5b6befd3f81a01b367 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 08:41:29 +0000 Subject: [PATCH 05/12] =?UTF-8?q?fix:=20SHEK-28=20=E2=80=94=20log=20K8s=20?= =?UTF-8?q?config=20failures=20instead=20of=20swallowing=20them?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Blanket except Exception: pass hid bad ConfigMap values, recursion bugs, and any other apply_k8s_config failure with no log or signal. Split into ImportError (silent — optional dep not installed) and Exception (warning with exc_info so operators can diagnose misconfigured ConfigMaps). Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 9 +++++- tests/test_kubernetes_integration.py | 45 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/shekel/_budget.py b/shekel/_budget.py index f6ac1db..4878e70 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -318,8 +318,15 @@ def __init__( from shekel.integrations.kubernetes import apply_k8s_config # noqa: PLC0415 apply_k8s_config(self) + except ImportError: + pass # kubernetes optional dependency not installed except Exception: - pass + import logging # noqa: PLC0415 + + logging.getLogger(__name__).warning( + "shekel[k8s]: Failed to apply Kubernetes config; K8s features disabled.", + exc_info=True, + ) # ------------------------------------------------------------------ # Internal state reset diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 065bcc3..9d53552 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -764,6 +764,51 @@ def test_redis_import_error_logs_warning(self, caplog: pytest.LogCaptureFixture) assert any("redis" in r.message.lower() for r in caplog.records) + def test_apply_k8s_config_exception_logs_warning( + self, caplog: pytest.LogCaptureFixture + ) -> None: + # Regression for SHEK-28: non-ImportError from apply_k8s_config must be + # logged as a warning so operators can diagnose misconfigured ConfigMaps. + import logging + from unittest.mock import patch + + from shekel._budget import Budget + + with patch( + "shekel.integrations.kubernetes.apply_k8s_config", + side_effect=ValueError("bad value"), + ): + with patch.dict( + "os.environ", + {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "test-budget"}, + ): + with caplog.at_level(logging.WARNING, logger="shekel._budget"): + Budget() + + assert any("K8s features disabled" in r.message for r in caplog.records) + + def test_apply_k8s_config_import_error_is_silent( + self, caplog: pytest.LogCaptureFixture + ) -> None: + # ImportError (optional dependency not installed) must remain silent. + import logging + from unittest.mock import patch + + from shekel._budget import Budget + + with patch( + "shekel.integrations.kubernetes.apply_k8s_config", + side_effect=ImportError("no module"), + ): + with patch.dict( + "os.environ", + {"KUBERNETES_SERVICE_HOST": "10.0.0.1", "SHEKEL_BUDGET_NAME": "test-budget"}, + ): + with caplog.at_level(logging.WARNING, logger="shekel._budget"): + Budget() + + assert not any("K8s features disabled" in r.message for r in caplog.records) + class TestShek17Fields: def test_flush_every_usd_stored(self) -> None: From 02354fd3be223986758106a156df8d9ad6242795 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 09:23:03 +0000 Subject: [PATCH 06/12] =?UTF-8?q?fix:=20SHEK-33=20=E2=80=94=20=5Fcheck=5Fp?= =?UTF-8?q?er=5Fpod=5Flimit()=20now=20respects=20warn=5Fonly=20mode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 4 ++++ tests/test_kubernetes_integration.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/shekel/_budget.py b/shekel/_budget.py index 4878e70..54114a6 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -828,6 +828,10 @@ def _check_per_pod_limit(self) -> None: if self._per_pod_cap_usd is None: return if self._spent > self._per_pod_cap_usd: + self._emit_budget_exceeded_event() + if self.warn_only: + self._check_warn() + return from shekel.exceptions import BudgetExceededError # noqa: PLC0415 raise BudgetExceededError( diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 9d53552..7762174 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -598,6 +598,14 @@ def test_per_pod_cap_not_enforced_when_absent(self) -> None: b._record_spend(0.50, "gpt-4o", {"input": 100, "output": 50}) b._record_spend(0.50, "gpt-4o", {"input": 100, "output": 50}) + def test_per_pod_cap_warn_only_does_not_raise(self) -> None: + # SHEK-33: warn_only=True must suppress the raise even when per_pod_cap is exceeded. + b = _budget_with_k8s({"per_pod_cap": "0.05"}, budget_kwargs={"warn_only": True}) + with b: + b._record_spend(0.03, "gpt-4o", {"input": 10, "output": 5}) + b._record_spend(0.03, "gpt-4o", {"input": 10, "output": 5}) # exceeds cap — must not raise + assert b._spent == pytest.approx(0.06) + # --------------------------------------------------------------------------- # SHEK-27: poller/reporter restart on Budget re-entry From 8c203122585f23ae66c1ab0ac811b7b2a75d9a18 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 09:31:43 +0000 Subject: [PATCH 07/12] =?UTF-8?q?fix:=20SHEK-32=20=E2=80=94=20report=20spe?= =?UTF-8?q?nd=20before=20enforcement=20checks=20so=20reporter=20captures?= =?UTF-8?q?=20limit-exceeding=20call=20cost?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 4 ++-- tests/test_kubernetes_integration.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/shekel/_budget.py b/shekel/_budget.py index 54114a6..52973bc 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -707,14 +707,14 @@ def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None ) self._calls_made += 1 self._append_velocity_entry(cost) + if self._k8s_reporter is not None: + self._k8s_reporter.on_spend(cost) self._check_velocity_warn() self._check_velocity_limit() self._check_warn() self._check_limit() self._check_call_limit() self._check_per_pod_limit() - if self._k8s_reporter is not None: - self._k8s_reporter.on_spend(cost) def _check_warn(self) -> None: effective_limit = self._effective_limit diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 7762174..6005dcd 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -1331,6 +1331,26 @@ def test_record_spend_does_not_notify_when_no_reporter(self) -> None: assert b._k8s_reporter is None b._record_spend(0.05, "gpt-4", {"input": 100, "output": 50}) # must not raise + # ── Budget integration ───────────────────────────────────────────────── + + def test_reporter_sees_cost_of_limit_exceeding_call(self) -> None: + # SHEK-32: on_spend must be called before enforcement checks so the + # reporter captures the cost of the call that triggers BudgetExceededError. + from shekel import Budget + from shekel.exceptions import BudgetExceededError + from shekel.integrations.kubernetes import KubernetesSpendReporter + + reporter = KubernetesSpendReporter("b", "ns") + b = Budget(max_usd=0.05) + b._k8s_reporter = reporter + + with b: + b._record_spend(0.03, "gpt-4o", {"input": 10, "output": 5}) + with pytest.raises(BudgetExceededError): + b._record_spend(0.03, "gpt-4o", {"input": 10, "output": 5}) + + assert reporter._total_spent == pytest.approx(0.06) + # ── Thread safety ────────────────────────────────────────────────────── def test_concurrent_on_spend_calls_thread_safe(self) -> None: From 00e71c8e7aedbb7d10b49c86f7924422ec142f50 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 09:41:44 +0000 Subject: [PATCH 08/12] =?UTF-8?q?fix:=20SHEK-31=20=E2=80=94=20stop=20leake?= =?UTF-8?q?d=20K8s=20threads=20after=20each=20test=20via=20autouse=20fixtu?= =?UTF-8?q?re;=20fix=20black=20line-length?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- tests/test_kubernetes_integration.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 6005dcd..e506091 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -8,6 +8,7 @@ import contextlib import sys import threading +from collections.abc import Iterator from typing import Any from unittest.mock import MagicMock, patch @@ -18,6 +19,14 @@ # --------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def _stop_k8s_threads() -> Iterator[None]: + yield + for thread in threading.enumerate(): + if thread.name.startswith("shekel-k8s-") and hasattr(thread, "stop"): + thread.stop() + + def _make_configmap(data: dict[str, str]) -> MagicMock: cm = MagicMock() cm.data = data @@ -603,7 +612,9 @@ def test_per_pod_cap_warn_only_does_not_raise(self) -> None: b = _budget_with_k8s({"per_pod_cap": "0.05"}, budget_kwargs={"warn_only": True}) with b: b._record_spend(0.03, "gpt-4o", {"input": 10, "output": 5}) - b._record_spend(0.03, "gpt-4o", {"input": 10, "output": 5}) # exceeds cap — must not raise + b._record_spend( + 0.03, "gpt-4o", {"input": 10, "output": 5} + ) # exceeds cap — must not raise assert b._spent == pytest.approx(0.06) From ecd0b9e93019553816c43c24809c2e042fc2fc78 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 09:47:19 +0000 Subject: [PATCH 09/12] =?UTF-8?q?feat:=20SHEK-29=20=E2=80=94=20add=20Budge?= =?UTF-8?q?tPausedError=20subclass=20to=20distinguish=20operator=20kill-sw?= =?UTF-8?q?itch=20from=20budget=20exhaustion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- shekel/__init__.py | 2 ++ shekel/_budget.py | 4 ++-- shekel/exceptions.py | 14 ++++++++++++++ tests/test_kubernetes_integration.py | 19 +++++++++++++++++-- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/shekel/__init__.py b/shekel/__init__.py index e3d446c..1181e53 100644 --- a/shekel/__init__.py +++ b/shekel/__init__.py @@ -10,6 +10,7 @@ AgentLoopError, BudgetConfigMismatchError, BudgetExceededError, + BudgetPausedError, ChainBudgetExceededError, NodeBudgetExceededError, SessionBudgetExceededError, @@ -25,6 +26,7 @@ "TemporalBudget", "with_budget", "BudgetExceededError", + "BudgetPausedError", "BudgetConfigMismatchError", "ToolBudgetExceededError", "NodeBudgetExceededError", diff --git a/shekel/_budget.py b/shekel/_budget.py index 52973bc..2252922 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -673,9 +673,9 @@ def reset(self) -> None: def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None: if self._paused_externally: - from shekel.exceptions import BudgetExceededError # noqa: PLC0415 + from shekel.exceptions import BudgetPausedError # noqa: PLC0415 - raise BudgetExceededError( + raise BudgetPausedError( spent=self._spent, limit=self.max_usd or 0.0, model=model, diff --git a/shekel/exceptions.py b/shekel/exceptions.py index 6bda3d6..f5adfe5 100644 --- a/shekel/exceptions.py +++ b/shekel/exceptions.py @@ -92,6 +92,20 @@ def __str__(self) -> str: ) +class BudgetPausedError(BudgetExceededError): + """Raised when an operator has paused the budget via Kubernetes ConfigMap kill-switch. + + Subclasses BudgetExceededError so existing except-clauses catch it without changes. + """ + + def __str__(self) -> str: + return ( + f"Budget paused by operator (${self.spent:.4f} spent)\n" + f" Last call: {self.model}\n" + f" Tip: Set paused=false in the shekel-budget ConfigMap to resume." + ) + + class NodeBudgetExceededError(BudgetExceededError): """Raised when a LangGraph node exceeds its budget cap. diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index e506091..6f8a30c 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -319,17 +319,32 @@ def test_paused_missing_does_not_set_flag(self) -> None: def test_paused_budget_raises_on_record_spend(self) -> None: from shekel._budget import Budget - from shekel.exceptions import BudgetExceededError + from shekel.exceptions import BudgetPausedError b = Budget(max_usd=1.00) b._paused_externally = True - with pytest.raises(BudgetExceededError): + with pytest.raises(BudgetPausedError): b._record_spend(0.01, "gpt-4o-mini", {"input": 10, "output": 5}) # Paused check fires before spend is accumulated assert b._spent == pytest.approx(0.0) + def test_paused_error_is_subclass_of_budget_exceeded_error(self) -> None: + from shekel.exceptions import BudgetExceededError, BudgetPausedError + + assert issubclass(BudgetPausedError, BudgetExceededError) + + def test_limit_exceeded_is_not_paused_error(self) -> None: + from shekel._budget import Budget + from shekel.exceptions import BudgetExceededError, BudgetPausedError + + b = Budget(max_usd=0.01) + with b: + with pytest.raises(BudgetExceededError) as exc_info: + b._record_spend(0.05, "gpt-4o", {"input": 10, "output": 5}) + assert not isinstance(exc_info.value, BudgetPausedError) + def test_not_paused_does_not_raise_on_record_spend(self) -> None: from shekel._budget import Budget From 397fd2251180f3f93e7a3a3261ad4f3cb9ef4e7d Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 12:45:20 +0000 Subject: [PATCH 10/12] =?UTF-8?q?fix:=20SHEK-30=20=E2=80=94=20wire=20=5Fk8?= =?UTF-8?q?s=5Fredis=5Fbackend=20into=20=5Frecord=5Fspend()=20for=20distri?= =?UTF-8?q?buted=20enforcement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 27 ++++++++++++++++ shekel/integrations/kubernetes.py | 1 + tests/test_kubernetes_integration.py | 47 ++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/shekel/_budget.py b/shekel/_budget.py index 2252922..970a7cc 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -307,6 +307,7 @@ def __init__( self._per_pod_cap_usd: float | None = None self._k8s_redis_backend: Any = None self._k8s_redis_name: str | None = None + self._k8s_redis_window_seconds: float = 86400.0 self._k8s_flush_every_usd: float | None = None self._k8s_flush_every_seconds: float | None = None self._k8s_scope_mode: str | None = None @@ -715,6 +716,7 @@ def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None self._check_limit() self._check_call_limit() self._check_per_pod_limit() + self._check_redis_limit(cost) def _check_warn(self) -> None: effective_limit = self._effective_limit @@ -838,6 +840,31 @@ def _check_per_pod_limit(self) -> None: self._spent, self._per_pod_cap_usd, self._last_model, self._last_tokens ) + def _check_redis_limit(self, cost: float) -> None: + """Distributed enforcement via Redis backend (SHEK-30).""" + if self._k8s_redis_backend is None or self._k8s_redis_name is None: + return + allowed, exceeded = self._k8s_redis_backend.check_and_add( + self._k8s_redis_name, + {"usd": cost}, + {"usd": self.max_usd}, + {"usd": self._k8s_redis_window_seconds}, + ) + if not allowed: + self._emit_budget_exceeded_event() + if self.warn_only: + self._check_warn() + return + from shekel.exceptions import BudgetExceededError # noqa: PLC0415 + + raise BudgetExceededError( + self._spent, + self.max_usd or 0.0, + self._last_model, + self._last_tokens, + exceeded_counter=exceeded, + ) + # ------------------------------------------------------------------ def _restart_k8s_threads(self) -> None: """Restart stopped K8s poller/reporter on Budget re-entry (SHEK-27). diff --git a/shekel/integrations/kubernetes.py b/shekel/integrations/kubernetes.py index 01be4a2..d6196d5 100644 --- a/shekel/integrations/kubernetes.py +++ b/shekel/integrations/kubernetes.py @@ -124,6 +124,7 @@ def apply_k8s_config(budget: Budget) -> None: budget._k8s_redis_backend = RedisBackend(url=redis_url) budget._k8s_redis_name = cm.get("redis_key", f"shekel:{namespace}:{budget_name}") + budget._k8s_redis_window_seconds = float(cm.get("redis_window_seconds", "86400")) except ImportError: logger.warning( "shekel[k8s]: 'redis' package not installed — skipping Redis backend." diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 6f8a30c..2cba362 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -587,6 +587,53 @@ def test_redis_backend_skipped_without_redis_url(self) -> None: assert not hasattr(b, "_k8s_redis_backend") or b._k8s_redis_backend is None + # ── Runtime enforcement ──────────────────────────────────────────────── + + def test_redis_check_and_add_called_on_spend(self) -> None: + from shekel import Budget + + b = Budget(max_usd=1.0) + mock_backend = MagicMock() + mock_backend.check_and_add.return_value = (True, None) + b._k8s_redis_backend = mock_backend + b._k8s_redis_name = "shekel:ns:budget" + + with b: + b._record_spend(0.10, "gpt-4o", {"input": 10, "output": 5}) + + mock_backend.check_and_add.assert_called_once_with( + "shekel:ns:budget", + {"usd": 0.10}, + {"usd": 1.0}, + {"usd": 86400.0}, + ) + + def test_redis_limit_raises_when_backend_rejects(self) -> None: + from shekel import Budget + from shekel.exceptions import BudgetExceededError + + b = Budget(max_usd=0.05) + mock_backend = MagicMock() + mock_backend.check_and_add.return_value = (False, "usd") + b._k8s_redis_backend = mock_backend + b._k8s_redis_name = "shekel:ns:budget" + + with b: + with pytest.raises(BudgetExceededError): + b._record_spend(0.10, "gpt-4o", {"input": 10, "output": 5}) + + def test_redis_warn_only_suppresses_raise(self) -> None: + from shekel import Budget + + b = Budget(max_usd=0.05, warn_only=True) + mock_backend = MagicMock() + mock_backend.check_and_add.return_value = (False, "usd") + b._k8s_redis_backend = mock_backend + b._k8s_redis_name = "shekel:ns:budget" + + with b: + b._record_spend(0.10, "gpt-4o", {"input": 10, "output": 5}) # must not raise + # --------------------------------------------------------------------------- # Per-pod cap From 6645684be05a06190fb9191fda7ce1c51e155592 Mon Sep 17 00:00:00 2001 From: Arie Radle Date: Wed, 10 Jun 2026 14:30:39 +0000 Subject: [PATCH 11/12] =?UTF-8?q?feat:=20SHEK-34=20=E2=80=94=20wire=20scop?= =?UTF-8?q?e=5Fgroup=5Fby=20(pod-label=20group=20detection)=20and=20scope?= =?UTF-8?q?=5Fmode=3Dshared=20(group-scoped=20Redis=20key)=20into=20runtim?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- shekel/_budget.py | 5 +- shekel/integrations/kubernetes.py | 49 +++++++++- tests/test_kubernetes_integration.py | 139 +++++++++++++++++++++++++++ 3 files changed, 186 insertions(+), 7 deletions(-) diff --git a/shekel/_budget.py b/shekel/_budget.py index 970a7cc..8e57545 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -312,6 +312,7 @@ def __init__( self._k8s_flush_every_seconds: float | None = None self._k8s_scope_mode: str | None = None self._k8s_scope_group_by: str | None = None + self._k8s_group_value: str = "" self._k8s_budget_name: str | None = None self._k8s_namespace: str | None = None self._k8s_poll_interval: float = 10.0 @@ -888,8 +889,6 @@ def _restart_k8s_threads(self) -> None: self._k8s_poller = poller if self._k8s_reporter is not None and not self._k8s_reporter.is_alive(): - import os # noqa: PLC0415 - from shekel.integrations.kubernetes import KubernetesSpendReporter # noqa: PLC0415 reporter = KubernetesSpendReporter( @@ -897,7 +896,7 @@ def _restart_k8s_threads(self) -> None: namespace=self._k8s_namespace or "", flush_every_seconds=self._k8s_flush_every_seconds or 60.0, flush_every_usd=self._k8s_flush_every_usd, - group_value=os.environ.get("SHEKEL_GROUP_VALUE", ""), + group_value=self._k8s_group_value, ) reporter.start() self._k8s_reporter = reporter diff --git a/shekel/integrations/kubernetes.py b/shekel/integrations/kubernetes.py index d6196d5..79166ba 100644 --- a/shekel/integrations/kubernetes.py +++ b/shekel/integrations/kubernetes.py @@ -66,6 +66,27 @@ def _fetch_configmap(budget_name: str, namespace: str) -> dict[str, str] | None: return None +def _read_pod_group_value(scope_group_by: str, namespace: str) -> str: + """Read the pod's label named *scope_group_by* and return its value.""" + pod_name = os.environ.get("HOSTNAME", "") + if not pod_name: + return "" + try: + import kubernetes # noqa: PLC0415 + + kubernetes.config.load_incluster_config() + v1 = kubernetes.client.CoreV1Api() + pod = v1.read_namespaced_pod(name=pod_name, namespace=namespace) + return str((pod.metadata.labels or {}).get(scope_group_by, "")) + except Exception as exc: + logger.warning( + "shekel[k8s]: Failed to read pod label %r for scope_group_by: %s", + scope_group_by, + exc, + ) + return "" + + def apply_k8s_config(budget: Budget) -> None: """Load K8s ConfigMap and apply values to *budget* (mutates in place). @@ -115,6 +136,17 @@ def apply_k8s_config(budget: Budget) -> None: if "per_pod_cap" in cm: budget._per_pod_cap_usd = float(cm["per_pod_cap"]) + # --- Scope resolution (SHEK-34): must run before Redis key construction --- + scope_group_by = cm.get("scope_group_by") + scope_mode = cm.get("scope_mode") + budget._k8s_scope_group_by = scope_group_by + budget._k8s_scope_mode = scope_mode + # env var takes priority over pod-label discovery (ConfigMap < env var pattern) + group_value = os.environ.get("SHEKEL_GROUP_VALUE", "") + if not group_value and scope_group_by: + group_value = _read_pod_group_value(scope_group_by, namespace) + budget._k8s_group_value = group_value + # --- Redis backend --- if cm.get("backend") == "redis": redis_url = os.environ.get("REDIS_URL") @@ -123,8 +155,20 @@ def apply_k8s_config(budget: Budget) -> None: from shekel.backends.redis import RedisBackend # noqa: PLC0415 budget._k8s_redis_backend = RedisBackend(url=redis_url) + # Default key is per-pod; scope_mode=shared promotes it to a group key budget._k8s_redis_name = cm.get("redis_key", f"shekel:{namespace}:{budget_name}") budget._k8s_redis_window_seconds = float(cm.get("redis_window_seconds", "86400")) + if scope_mode == "shared" and "redis_key" not in cm: + if budget._k8s_group_value: + budget._k8s_redis_name = ( + f"shekel:{namespace}:{budget_name}:{budget._k8s_group_value}" + ) + else: + logger.warning( + "shekel[k8s]: scope_mode=shared requires scope_group_by with a " + "resolvable pod label or SHEKEL_GROUP_VALUE; " + "falling back to per-pod Redis key." + ) except ImportError: logger.warning( "shekel[k8s]: 'redis' package not installed — skipping Redis backend." @@ -135,18 +179,15 @@ def apply_k8s_config(budget: Budget) -> None: budget._k8s_flush_every_seconds = ( float(cm["flush_every_seconds"]) if "flush_every_seconds" in cm else None ) - budget._k8s_scope_mode = cm.get("scope_mode") - budget._k8s_scope_group_by = cm.get("scope_group_by") # --- SHEK-17: Spend reporter (only when backend=k8s) --- if cm.get("backend") == "k8s": - group_value = os.environ.get("SHEKEL_GROUP_VALUE", "") reporter = KubernetesSpendReporter( budget_name=budget_name, namespace=namespace, flush_every_seconds=budget._k8s_flush_every_seconds or 60.0, flush_every_usd=budget._k8s_flush_every_usd, - group_value=group_value, + group_value=budget._k8s_group_value, ) reporter.start() budget._k8s_reporter = reporter diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 2cba362..9701b10 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -909,6 +909,145 @@ def test_scope_group_by_stored(self) -> None: assert b._k8s_scope_group_by == "team" +# --------------------------------------------------------------------------- +# SHEK-34: scope_group_by / scope_mode runtime behaviour +# --------------------------------------------------------------------------- + + +def _budget_with_k8s_and_pod_labels( + configmap_data: dict[str, str], + pod_labels: dict[str, str], + extra_env: dict[str, str] | None = None, +) -> Any: + """Like _budget_with_k8s but also mocks read_namespaced_pod to return pod_labels.""" + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock(configmap_data) + pod_mock = MagicMock() + pod_mock.metadata.labels = pod_labels + k8s_mock.client.CoreV1Api.return_value.read_namespaced_pod.return_value = pod_mock + + env = { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "test-budget", + "HOSTNAME": "worker-pod-1", + **(extra_env or {}), + } + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + b = Budget() + return b + + +class TestScopeResolution: + def test_scope_group_by_reads_pod_label(self) -> None: + b = _budget_with_k8s_and_pod_labels( + {"scope_group_by": "team"}, + pod_labels={"team": "backend"}, + ) + assert b._k8s_group_value == "backend" + + def test_scope_group_by_env_var_overrides_pod_label(self) -> None: + b = _budget_with_k8s_and_pod_labels( + {"scope_group_by": "team"}, + pod_labels={"team": "backend"}, + extra_env={"SHEKEL_GROUP_VALUE": "frontend"}, + ) + assert b._k8s_group_value == "frontend" + + def test_scope_group_by_falls_back_gracefully(self) -> None: + # Pod read raises — group_value stays "" and no exception propagates. + from shekel._budget import Budget + + k8s_mock = _make_k8s_mock({"scope_group_by": "team"}) + k8s_mock.client.CoreV1Api.return_value.read_namespaced_pod.side_effect = OSError( + "forbidden" + ) + env = { + "KUBERNETES_SERVICE_HOST": "10.0.0.1", + "SHEKEL_BUDGET_NAME": "test-budget", + "HOSTNAME": "worker-pod-1", + } + with patch.dict("os.environ", env, clear=False): + with patch.dict( + "sys.modules", + { + "kubernetes": k8s_mock, + "kubernetes.client": k8s_mock.client, + "kubernetes.config": k8s_mock.config, + }, + ): + with patch( + "builtins.open", + MagicMock( + return_value=MagicMock( + __enter__=MagicMock( + return_value=MagicMock(read=MagicMock(return_value="default")) + ), + __exit__=MagicMock(return_value=False), + ) + ), + ): + b = Budget() + assert b._k8s_group_value == "" + + def test_scope_mode_shared_scopes_redis_key_to_group(self) -> None: + b = _budget_with_k8s_and_pod_labels( + {"scope_mode": "shared", "scope_group_by": "team", "backend": "redis"}, + pod_labels={"team": "svc-a"}, + extra_env={"REDIS_URL": "redis://localhost"}, + ) + assert b._k8s_redis_name == "shekel:default:test-budget:svc-a" + + def test_scope_mode_shared_warns_when_no_group(self, caplog: Any) -> None: + import logging + + with caplog.at_level(logging.WARNING, logger="shekel.integrations.kubernetes"): + b = _budget_with_k8s_and_pod_labels( + {"scope_mode": "shared", "backend": "redis"}, + pod_labels={}, # no scope_group_by, no SHEKEL_GROUP_VALUE + extra_env={"REDIS_URL": "redis://localhost"}, + ) + assert b._k8s_redis_name == "shekel:default:test-budget" + assert any("falling back to per-pod" in r.message for r in caplog.records) + + def test_scope_mode_per_pod_does_not_change_redis_key(self) -> None: + b = _budget_with_k8s_and_pod_labels( + {"scope_mode": "per_pod", "scope_group_by": "team", "backend": "redis"}, + pod_labels={"team": "svc-a"}, + extra_env={"REDIS_URL": "redis://localhost"}, + ) + assert b._k8s_redis_name == "shekel:default:test-budget" + + def test_reporter_uses_k8s_group_value_not_env_var(self) -> None: + # group_value must come from _k8s_group_value (resolved at init), not a + # fresh SHEKEL_GROUP_VALUE read at reporter creation time. + b = _budget_with_k8s_and_pod_labels( + {"scope_group_by": "team", "backend": "k8s"}, + pod_labels={"team": "payments"}, + ) + assert b._k8s_reporter is not None + assert b._k8s_reporter._group_value == "payments" + + # --------------------------------------------------------------------------- # SHEK-17: KubernetesSpendReporter # --------------------------------------------------------------------------- From c3b777c4944600b0d9a1a472fc16f27f10a7529f Mon Sep 17 00:00:00 2001 From: aradledessi Date: Wed, 17 Jun 2026 10:21:45 +0300 Subject: [PATCH 12/12] test: fix coverage gaps and litellm skip guard on feat/shek-16 - Fix _check_redis_limit raise path (lines 859-861): test used max_usd=0.05 which caused _check_limit to fire first, never reaching redis path - Add chain() tests to cover happy path and invalid-arg guard (lines 1362-1365) - Add _litellm_available() skipif to TestLiteLLMPatching and TestLiteLLMCostRecording so they skip when litellm optional dep is absent Co-Authored-By: Claude Sonnet 4.6 --- tests/providers/test_litellm_adapter.py | 14 ++++++++++++++ tests/test_budget.py | 20 ++++++++++++++++++++ tests/test_kubernetes_integration.py | 3 ++- 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/providers/test_litellm_adapter.py b/tests/providers/test_litellm_adapter.py index 61330bd..75a93a3 100644 --- a/tests/providers/test_litellm_adapter.py +++ b/tests/providers/test_litellm_adapter.py @@ -9,6 +9,18 @@ from tests.providers.conftest import MockLiteLLMChunk, ProviderTestBase +def _litellm_available() -> bool: + try: + import litellm # noqa: F401 + + return True + except ImportError: + return False + + +_skip_no_litellm = pytest.mark.skipif(not _litellm_available(), reason="litellm not installed") + + class TestLiteLLMAdapterBasic(ProviderTestBase): def test_name_is_litellm(self): @@ -163,6 +175,7 @@ def stream_with_bad_usage(): assert len(chunks) == 2 +@_skip_no_litellm class TestLiteLLMPatching(ProviderTestBase): def test_install_patches_when_litellm_available(self): @@ -220,6 +233,7 @@ def test_remove_patches_safe_without_litellm(self): adapter.remove_patches() # Must not raise +@_skip_no_litellm class TestLiteLLMCostRecording(ProviderTestBase): def test_completion_records_cost(self): diff --git a/tests/test_budget.py b/tests/test_budget.py index 27e894d..cfc5e8a 100644 --- a/tests/test_budget.py +++ b/tests/test_budget.py @@ -290,3 +290,23 @@ def test_record_swallows_adapter_emit_exception() -> None: side_effect=RuntimeError("adapter crash"), ): _record(100, 50, "gpt-4o-mini") # must not raise + + +def test_chain_raises_on_nonpositive_max_usd() -> None: + from shekel import Budget + + b = Budget(max_usd=1.0) + with pytest.raises(ValueError, match="chain max_usd must be positive"): + b.chain("step", max_usd=0.0) + with pytest.raises(ValueError, match="chain max_usd must be positive"): + b.chain("step", max_usd=-1.0) + + +def test_chain_registers_component_budget() -> None: + from shekel import Budget + + b = Budget(max_usd=1.0) + result = b.chain("summarize", max_usd=0.50) + assert result is b + assert "summarize" in b._chain_budgets + assert b._chain_budgets["summarize"].max_usd == pytest.approx(0.50) diff --git a/tests/test_kubernetes_integration.py b/tests/test_kubernetes_integration.py index 9701b10..e9c2aab 100644 --- a/tests/test_kubernetes_integration.py +++ b/tests/test_kubernetes_integration.py @@ -612,7 +612,8 @@ def test_redis_limit_raises_when_backend_rejects(self) -> None: from shekel import Budget from shekel.exceptions import BudgetExceededError - b = Budget(max_usd=0.05) + # max_usd must be high enough that _check_limit() doesn't fire first + b = Budget(max_usd=5.0) mock_backend = MagicMock() mock_backend.check_and_add.return_value = (False, "usd") b._k8s_redis_backend = mock_backend