Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions e2e_test/realtime/test_realtime_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
- Session lifecycle (connect, session.created, session.update)
- Text generation (single-turn and multi-turn conversations)
- Response cancellation mid-stream
- Response format validation (session.created, response.done, response.text.delta)
- Response format validation (session.created, response.done, response.output_text.delta)
- Error handling (invalid events, missing model, missing auth)

Prerequisites:
Expand Down Expand Up @@ -97,7 +97,7 @@ async def _collect_response_text(ws, *, timeout: float = RECV_TIMEOUT) -> str:
if event is None:
continue
etype = event.get("type", "")
if etype == "response.text.delta" and event.get("delta"):
if etype == "response.output_text.delta" and event.get("delta"):
parts.append(event["delta"])
elif etype == "response.done":
break
Expand All @@ -111,7 +111,12 @@ async def _realtime_session(ws_url: str, ws_headers: dict):
"""Connect, wait for session.created, configure text modality, yield ws."""
async with websockets.connect(ws_url, additional_headers=ws_headers) as ws:
await _recv_event(ws, event_type="session.created")
await ws.send(_make_event("session.update", session={"modalities": ["text"]}))
await ws.send(
_make_event(
"session.update",
session={"type": "realtime", "output_modalities": ["text"]},
)
)
await _recv_event(ws, event_type="session.updated")
yield ws

Expand Down Expand Up @@ -145,7 +150,6 @@ def ws_headers():
"""Build the WebSocket connection headers."""
return {
"Authorization": f"Bearer {OPENAI_API_KEY}",
"OpenAI-Beta": "realtime=v1",
}


Expand Down Expand Up @@ -179,11 +183,16 @@ def test_session_update(self, ws_url, ws_headers):
async def _run():
async with websockets.connect(ws_url, additional_headers=ws_headers) as ws:
await _recv_event(ws, event_type="session.created")
await ws.send(_make_event("session.update", session={"modalities": ["text"]}))
await ws.send(
_make_event(
"session.update",
session={"type": "realtime", "output_modalities": ["text"]},
)
)
event = await _recv_event(ws, event_type="session.updated")
assert event["type"] == "session.updated"
assert "session" in event
assert event["session"].get("modalities") == ["text"]
assert event["session"].get("output_modalities") == ["text"]
logger.info("Session updated successfully")

asyncio.run(_run())
Expand All @@ -197,7 +206,7 @@ async def _run():
await ws.send(
_make_event(
"response.create",
response={"modalities": ["text"]},
response={"output_modalities": ["text"]},
)
)

Expand All @@ -214,30 +223,39 @@ async def _run():
async with _realtime_session(ws_url, ws_headers) as ws:
# Turn 1
await ws.send(_make_user_message("My name is Alice."))
await ws.send(_make_event("response.create", response={"modalities": ["text"]}))
await ws.send(
_make_event("response.create", response={"output_modalities": ["text"]})
)
text1 = await _collect_response_text(ws)
assert len(text1) > 0
logger.info("Turn 1: %s", text1[:100])

# Turn 2 — model should remember the name
await ws.send(_make_user_message("What is my name?"))
await ws.send(_make_event("response.create", response={"modalities": ["text"]}))
await ws.send(
_make_event("response.create", response={"output_modalities": ["text"]})
)
text2 = await _collect_response_text(ws)
assert "alice" in text2.lower(), f"Expected 'Alice' in response, got: {text2}"
logger.info("Turn 2: %s", text2[:100])

asyncio.run(_run())

def test_conversation_item_created_event(self, ws_url, ws_headers):
"""Sending conversation.item.create should echo conversation.item.created."""
def test_conversation_item_added_event(self, ws_url, ws_headers):
"""Sending conversation.item.create should echo conversation.item.added.

GA renamed the legacy `conversation.item.created` event to
`conversation.item.added` (emitted when an item is added to the default
conversation).
"""

async def _run():
async with _realtime_session(ws_url, ws_headers) as ws:
await ws.send(_make_user_message("Hi"))
event = await _recv_event(ws, event_type="conversation.item.created")
assert event["type"] == "conversation.item.created"
event = await _recv_event(ws, event_type="conversation.item.added")
assert event["type"] == "conversation.item.added"
assert event["item"]["role"] == "user"
logger.info("conversation.item.created received: id=%s", event["item"].get("id"))
logger.info("conversation.item.added received: id=%s", event["item"].get("id"))

asyncio.run(_run())

Expand All @@ -249,10 +267,12 @@ async def _run():
await ws.send(
_make_user_message("Write a very long essay about the history of computing.")
)
await ws.send(_make_event("response.create", response={"modalities": ["text"]}))
await ws.send(
_make_event("response.create", response={"output_modalities": ["text"]})
)

# Wait for first delta to confirm streaming started
await _recv_event(ws, event_type="response.text.delta")
await _recv_event(ws, event_type="response.output_text.delta")

# Cancel mid-stream
await ws.send(_make_event("response.cancel"))
Expand All @@ -273,15 +293,20 @@ async def _run():
# Top-level fields
assert "event_id" in event, "Missing event_id"
assert event["type"] == "session.created"
# Session object
# Session object (GA shape)
session = event["session"]
assert isinstance(session, dict)
assert isinstance(session.get("id"), str)
assert len(session["id"]) > 0
assert isinstance(session.get("model"), str)
assert isinstance(session.get("modalities"), list)
assert isinstance(session.get("voice"), str)
assert isinstance(session.get("turn_detection"), (dict, type(None)))
assert isinstance(session.get("output_modalities"), list)
# In GA, voice/turn_detection moved under audio.{output,input}.
audio = session.get("audio")
assert isinstance(audio, dict), f"Expected session.audio dict, got: {audio!r}"
output = audio.get("output") or {}
assert isinstance(output.get("voice"), str)
input_cfg = audio.get("input") or {}
assert isinstance(input_cfg.get("turn_detection"), (dict, type(None)))
Comment on lines +308 to +309
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Nit: Unlike audio.output (where the isinstance(voice, str) assertion catches a missing key), a missing audio.input passes silently here — when audio.get("input") is missing, None, or {}, the `or {}` fallback yields {}, so {}.get("turn_detection") returns None, and isinstance(None, (dict, type(None))) is True.

Consider asserting that audio.input exists as a non-empty dict, mirroring the audio assertion on line 300:

Suggested change
input_cfg = audio.get("input") or {}
assert isinstance(input_cfg.get("turn_detection"), (dict, type(None)))
input_cfg = audio.get("input") or {}
assert isinstance(input_cfg, dict) and input_cfg, f"Expected session.audio.input dict, got: {input_cfg!r}"
assert isinstance(input_cfg.get("turn_detection"), (dict, type(None)))

logger.info(
"session.created schema OK: id=%s model=%s",
session["id"],
Expand All @@ -296,7 +321,9 @@ def test_response_done_format(self, ws_url, ws_headers):
async def _run():
async with _realtime_session(ws_url, ws_headers) as ws:
await ws.send(_make_user_message("Say hi."))
await ws.send(_make_event("response.create", response={"modalities": ["text"]}))
await ws.send(
_make_event("response.create", response={"output_modalities": ["text"]})
)

event = await _recv_event(ws, event_type="response.done")
# Top-level
Expand All @@ -309,14 +336,14 @@ async def _run():
assert resp.get("status") == "completed"
assert isinstance(resp.get("output"), list)
assert len(resp["output"]) > 0
# Output item
# Output item — GA shape uses content type "output_text".
item = resp["output"][0]
assert item.get("type") == "message"
assert item.get("role") == "assistant"
assert isinstance(item.get("content"), list)
assert len(item["content"]) > 0
content = item["content"][0]
assert content.get("type") == "text"
assert content.get("type") == "output_text"
assert isinstance(content.get("text"), str)
assert len(content["text"]) > 0
# Usage
Expand All @@ -333,12 +360,14 @@ async def _run():
asyncio.run(_run())

def test_response_text_delta_format(self, ws_url, ws_headers):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Nit: Method name still says text_delta, but the event was renamed from response.text.delta to response.output_text.delta. Consider renaming for grep-ability:

Suggested change
def test_response_text_delta_format(self, ws_url, ws_headers):
def test_response_output_text_delta_format(self, ws_url, ws_headers):

"""Validate response.text.delta events have the expected schema."""
"""Validate response.output_text.delta events have the expected schema."""

async def _run():
async with _realtime_session(ws_url, ws_headers) as ws:
await ws.send(_make_user_message("Say hello."))
await ws.send(_make_event("response.create", response={"modalities": ["text"]}))
await ws.send(
_make_event("response.create", response={"output_modalities": ["text"]})
)

# Collect a few deltas and validate schema
delta_count = 0
Expand All @@ -351,7 +380,7 @@ async def _run():
event = _parse_event(raw)
if event is None:
continue
if event.get("type") == "response.text.delta":
if event.get("type") == "response.output_text.delta":
assert "event_id" in event
assert isinstance(event.get("delta"), str)
assert len(event["delta"]) > 0
Expand All @@ -363,8 +392,8 @@ async def _run():
elif event.get("type") == "response.done":
break

assert delta_count > 0, "Expected at least one response.text.delta"
logger.info("response.text.delta schema OK: %d deltas received", delta_count)
assert delta_count > 0, "Expected at least one response.output_text.delta"
logger.info("response.output_text.delta schema OK: %d deltas received", delta_count)

asyncio.run(_run())

Expand Down
7 changes: 4 additions & 3 deletions model_gateway/src/routers/openai/realtime/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ pub async fn run_ws_proxy(
// Connect to upstream WebSocket with auth.
// Let tungstenite auto-add WebSocket handshake headers (Connection, Upgrade,
// Sec-WebSocket-Version, Sec-WebSocket-Key); we only add app-specific headers.
//
// Do not send `OpenAI-Beta: realtime=v1` — OpenAI's GA Realtime API rejects
// it with `beta_api_shape_disabled` ("The Realtime Beta API is no longer
// supported. Please use /v1/realtime for the GA API.").
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
let mut request = upstream_url.into_client_request()?;
request
.headers_mut()
.insert("Authorization", auth_header.parse()?);
request
.headers_mut()
.insert("OpenAI-Beta", "realtime=v1".parse()?);

// Build an explicit rustls TLS connector so we don't depend on the
// process-level CryptoProvider being installed.
Expand Down
Loading