Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
interactions:
- request:
body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"Say
hi in one word."}],"max_tokens":8,"stream":true}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '119'
Host:
- api.cohere.com
User-Agent:
- cohere/6.1.0
X-Fern-Language:
- Python
X-Fern-Platform:
- darwin/25.2.0
X-Fern-Runtime:
- python/3.14.3
X-Fern-SDK-Name:
- cohere
X-Fern-SDK-Version:
- 6.1.0
content-type:
- application/json
method: POST
uri: https://api.cohere.com/v2/chat
response:
body:
string: 'event: message-start

data: {"id":"b4256f33-1304-4943-b795-7f0d56d1fec9","type":"message-start","delta":{"message":{"role":"assistant","content":[],"tool_plan":"","tool_calls":[],"citations":[]}}}


event: content-start

data: {"type":"content-start","index":0,"delta":{"message":{"content":{"type":"text","text":""}}}}


event: content-delta

data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"Hi"}}}}


event: content-delta

data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"!"}}}}


event: content-end

data: {"type":"content-end","index":0}


event: message-end

data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":6,"output_tokens":2},"tokens":{"input_tokens":501,"output_tokens":4},"cached_tokens":0}}}


data: [DONE]


'
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Transfer-Encoding:
- chunked
Via:
- 1.1 google
access-control-expose-headers:
- X-Debug-Trace-ID
cache-control:
- no-cache, no-store, no-transform, must-revalidate, private, max-age=0
content-type:
- text/event-stream
date:
- Thu, 16 Apr 2026 22:25:13 GMT
expires:
- Thu, 01 Jan 1970 00:00:00 GMT
pragma:
- no-cache
server:
- envoy
vary:
- Origin
x-accel-expires:
- '0'
x-debug-trace-id:
- 6ad8e9afe35e1e8627c7779791b16cc4
x-endpoint-monthly-call-limit:
- '1000'
x-envoy-upstream-service-time:
- '303'
x-trial-endpoint-call-limit:
- '20'
x-trial-endpoint-call-remaining:
- '17'
status:
code: 200
message: OK
version: 1
28 changes: 28 additions & 0 deletions py/src/braintrust/integrations/cohere/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,34 @@ def test_wrap_cohere_chat_stream_v2_sync(memory_logger):
assert metrics.get("completion_tokens", 0) > 0


@pytest.mark.vcr
def test_wrap_cohere_chat_stream_v2_sync_context_manager(memory_logger):
assert not memory_logger.pop()
client = wrap_cohere(_v2_client(require_methods=("chat_stream",)))

start = time.time()
events = []
with client.chat_stream(
model=CHAT_MODEL,
messages=[{"role": "user", "content": "Say hi in one word."}],
max_tokens=8,
) as stream:
for event in stream:
events.append(event)
end = time.time()

assert events

spans = memory_logger.pop()
assert len(spans) == 1
span = spans[0]
assert span["span_attributes"]["name"] == "cohere.chat_stream"
assert span["metadata"]["provider"] == "cohere"
assert span["metadata"]["model"] == CHAT_MODEL
assert span["metrics"]["start"] <= span["metrics"]["end"]
assert start <= span["metrics"]["start"] <= span["metrics"]["end"] <= end


@pytest.mark.vcr
def test_wrap_cohere_chat_stream_v2_rag_citations(memory_logger):
if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest":
Expand Down
40 changes: 40 additions & 0 deletions py/src/braintrust/integrations/cohere/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,26 @@ def _finish(self, error: BaseException | None = None) -> None:
class _TracedChatStream(_ChatStreamTracker):
"""Wrap a sync chat-stream iterator so exhaustion logs the aggregated span."""

def __enter__(self):
context_manager = self._iterator
enter = getattr(context_manager, "__enter__", None)
if enter is not None:
self._iterator = enter()
self._context_manager = context_manager
return self

def __exit__(self, exc_type, exc_value, traceback):
suppress = False
context_manager = getattr(self, "_context_manager", self._iterator)
exit_method = getattr(context_manager, "__exit__", None)
if exit_method is not None:
suppress = bool(exit_method(exc_type, exc_value, traceback))
if exc_value is not None:
self._finish(error=exc_value)
else:
self._finish()
return suppress

def __iter__(self):
return self

Expand All @@ -859,6 +879,26 @@ def __next__(self):
class _AsyncTracedChatStream(_ChatStreamTracker):
"""Async counterpart of :class:`_TracedChatStream`."""

async def __aenter__(self):
context_manager = self._iterator
aenter = getattr(context_manager, "__aenter__", None)
if aenter is not None:
self._iterator = await aenter()
self._context_manager = context_manager
return self

async def __aexit__(self, exc_type, exc_value, traceback):
suppress = False
context_manager = getattr(self, "_context_manager", self._iterator)
aexit = getattr(context_manager, "__aexit__", None)
if aexit is not None:
suppress = bool(await aexit(exc_type, exc_value, traceback))
if exc_value is not None:
self._finish(error=exc_value)
else:
self._finish()
return suppress

def __aiter__(self):
return self

Expand Down
Loading