diff --git a/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_stream_v2_sync_context_manager.yaml b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_stream_v2_sync_context_manager.yaml new file mode 100644 index 00000000..229d9e4d --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_stream_v2_sync_context_manager.yaml @@ -0,0 +1,106 @@ +interactions: +- request: + body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"Say + hi in one word."}],"max_tokens":8,"stream":true}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '119' + Host: + - api.cohere.com + User-Agent: + - cohere/6.1.0 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.14.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 6.1.0 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: 'event: message-start + + data: {"id":"b4256f33-1304-4943-b795-7f0d56d1fec9","type":"message-start","delta":{"message":{"role":"assistant","content":[],"tool_plan":"","tool_calls":[],"citations":[]}}} + + + event: content-start + + data: {"type":"content-start","index":0,"delta":{"message":{"content":{"type":"text","text":""}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"Hi"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"!"}}}} + + + event: content-end + + data: {"type":"content-end","index":0} + + + event: message-end + + data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":6,"output_tokens":2},"tokens":{"input_tokens":501,"output_tokens":4},"cached_tokens":0}}} + + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; 
ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-type: + - text/event-stream + date: + - Thu, 16 Apr 2026 22:25:13 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + pragma: + - no-cache + server: + - envoy + vary: + - Origin + x-accel-expires: + - '0' + x-debug-trace-id: + - 6ad8e9afe35e1e8627c7779791b16cc4 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '303' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '17' + status: + code: 200 + message: OK +version: 1
diff --git a/py/src/braintrust/integrations/cohere/test_cohere.py b/py/src/braintrust/integrations/cohere/test_cohere.py
index 0bbeb0bd..d131039b 100644
--- a/py/src/braintrust/integrations/cohere/test_cohere.py
+++ b/py/src/braintrust/integrations/cohere/test_cohere.py
@@ -567,6 +567,35 @@ def test_wrap_cohere_chat_stream_v2_sync(memory_logger):
     assert metrics.get("completion_tokens", 0) > 0
 
 
+@pytest.mark.vcr
+def test_wrap_cohere_chat_stream_v2_sync_context_manager(memory_logger):
+    """Iterating ``chat_stream`` inside ``with`` yields events and logs one span."""
+    assert not memory_logger.pop()
+    client = wrap_cohere(_v2_client(require_methods=("chat_stream",)))
+
+    start = time.time()
+    events = []
+    with client.chat_stream(
+        model=CHAT_MODEL,
+        messages=[{"role": "user", "content": "Say hi in one word."}],
+        max_tokens=8,
+    ) as stream:
+        for event in stream:
+            events.append(event)
+    end = time.time()
+
+    assert events
+
+    spans = memory_logger.pop()
+    assert len(spans) == 1
+    span = spans[0]
+    assert span["span_attributes"]["name"] == "cohere.chat_stream"
+    assert span["metadata"]["provider"] == "cohere"
+    assert span["metadata"]["model"] == CHAT_MODEL
+    assert span["metrics"]["start"] <= span["metrics"]["end"]
+    assert start <= span["metrics"]["start"] <= span["metrics"]["end"] <= end
+
+
 @pytest.mark.vcr
 def test_wrap_cohere_chat_stream_v2_rag_citations(memory_logger):
     if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest":
diff --git a/py/src/braintrust/integrations/cohere/tracing.py b/py/src/braintrust/integrations/cohere/tracing.py
index dfc89fd2..0f563616 100644
--- a/py/src/braintrust/integrations/cohere/tracing.py
+++ b/py/src/braintrust/integrations/cohere/tracing.py
@@ -840,6 +840,48 @@ def _finish(self, error: BaseException | None = None) -> None:
 class _TracedChatStream(_ChatStreamTracker):
     """Wrap a sync chat-stream iterator so exhaustion logs the aggregated span."""
 
+    def __enter__(self):
+        """Enter the wrapped stream's context manager, if it has one.
+
+        If entering raises, the span is finished with that error before it
+        propagates, since ``__exit__`` does not run when ``__enter__`` fails.
+        """
+        context_manager = self._iterator
+        enter = getattr(context_manager, "__enter__", None)
+        if enter is not None:
+            try:
+                self._iterator = enter()
+            except BaseException as error:
+                self._finish(error=error)
+                raise
+            # Keep the manager itself: ``self._iterator`` may now be a different
+            # object, but ``__exit__`` still belongs to the original one.
+            self._context_manager = context_manager
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Exit the wrapped stream's context and finish the span.
+
+        Returns the wrapped ``__exit__`` result as a bool (``False`` when the
+        stream has no ``__exit__``). The span is always finished, recording
+        the in-flight exception or one raised by the wrapped ``__exit__``.
+        """
+        suppress = False
+        error = exc_value
+        context_manager = getattr(self, "_context_manager", self._iterator)
+        exit_method = getattr(context_manager, "__exit__", None)
+        try:
+            if exit_method is not None:
+                suppress = bool(exit_method(exc_type, exc_value, traceback))
+        except BaseException as exit_error:
+            error = exit_error
+            raise
+        finally:
+            # Runs even when the wrapped ``__exit__`` raises, so the span is
+            # not left unfinished.
+            self._finish(error=error)
+        return suppress
+
     def __iter__(self):
         return self
 
@@ -859,6 +901,48 @@ def __next__(self):
 class _AsyncTracedChatStream(_ChatStreamTracker):
     """Async counterpart of :class:`_TracedChatStream`."""
 
+    async def __aenter__(self):
+        """Enter the wrapped stream's async context manager, if it has one.
+
+        If entering raises, the span is finished with that error before it
+        propagates, since ``__aexit__`` does not run when ``__aenter__`` fails.
+        """
+        context_manager = self._iterator
+        aenter = getattr(context_manager, "__aenter__", None)
+        if aenter is not None:
+            try:
+                self._iterator = await aenter()
+            except BaseException as error:
+                self._finish(error=error)
+                raise
+            # Keep the manager itself: ``self._iterator`` may now be a different
+            # object, but ``__aexit__`` still belongs to the original one.
+            self._context_manager = context_manager
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        """Exit the wrapped stream's async context and finish the span.
+
+        Returns the wrapped ``__aexit__`` result as a bool (``False`` when the
+        stream has no ``__aexit__``). The span is always finished, recording
+        the in-flight exception or one raised by the wrapped ``__aexit__``.
+        """
+        suppress = False
+        error = exc_value
+        context_manager = getattr(self, "_context_manager", self._iterator)
+        aexit = getattr(context_manager, "__aexit__", None)
+        try:
+            if aexit is not None:
+                suppress = bool(await aexit(exc_type, exc_value, traceback))
+        except BaseException as exit_error:
+            error = exit_error
+            raise
+        finally:
+            # Runs even when the wrapped ``__aexit__`` raises, so the span is
+            # not left unfinished.
+            self._finish(error=error)
+        return suppress
+
     def __aiter__(self):
         return self
 