diff --git a/e2e_test/mlx/test_mlx_backend.py b/e2e_test/mlx/test_mlx_backend.py index 00d9be604..401c02f53 100644 --- a/e2e_test/mlx/test_mlx_backend.py +++ b/e2e_test/mlx/test_mlx_backend.py @@ -13,6 +13,7 @@ import platform import sys +import openai import pytest pytestmark = pytest.mark.skipif( @@ -21,6 +22,36 @@ ) +def collect_streamed_completion(stream): + """Collect all text and the final choice from a streaming completion response.""" + chunks = list(stream) + text = "".join(c.choices[0].text for c in chunks if c.choices and c.choices[0].text) + final_choice = next( + c.choices[0] for c in reversed(chunks) if c.choices and c.choices[0].finish_reason + ) + return text, final_choice + + +def assert_stop_text_trimmed(text, stop_text): + assert stop_text not in text, f"Stop text {stop_text!r} should not appear in output: {text!r}" + + +def assert_matched_stop(choice, expected): + actual = getattr(choice, "matched_stop", None) + assert actual == expected, f"Expected matched_stop={expected!r}, got {actual!r}" + + +def assert_api_error(err, expected_status: int, expected_code: str) -> None: + assert err.status_code == expected_status, ( + f"Expected HTTP {expected_status}, got {err.status_code}" + ) + body = getattr(err, "body", None) or {} + error_str = str(body) + str(getattr(err, "message", "")) + assert expected_code in error_str, ( + f"Expected {expected_code!r} in error body, got: {error_str!r}" + ) + + WEATHER_TOOL = { "type": "function", "function": { @@ -60,6 +91,9 @@ class TestMlxBackend: # actual content within max_tokens. NO_THINKING = {"chat_template_kwargs": {"enable_thinking": False}} + STOP_SEQUENCE_TEST_PROMPT = "Repeat: 1 2 3 hello world 4 5 6 7" + SINGLE_STRING_STOP = "6" + def test_basic_chat(self, model, api_client): response = api_client.chat.completions.create( model=model, @@ -153,3 +187,66 @@ def test_max_tokens_finish_reason(self, model, api_client): ) assert response.choices[0].finish_reason == "length" assert response.usage.completion_tokens == 10 + + def test_chat_stop_string_non_streaming(self, model, api_client): + response = api_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": self.STOP_SEQUENCE_TEST_PROMPT}], + max_tokens=160, + temperature=0, + stop=[self.SINGLE_STRING_STOP], + extra_body=self.NO_THINKING, + ) + choice = response.choices[0] + assert choice.finish_reason == "stop" + assert_stop_text_trimmed(choice.message.content or "", self.SINGLE_STRING_STOP) + assert_matched_stop(choice, self.SINGLE_STRING_STOP) + + def test_completion_stop_string_non_streaming(self, model, api_client): + response = api_client.completions.create( + model=model, + prompt=self.STOP_SEQUENCE_TEST_PROMPT, + max_tokens=160, + temperature=0, + stop=[self.SINGLE_STRING_STOP], + ) + choice = response.choices[0] + assert choice.finish_reason == "stop" + assert_stop_text_trimmed(choice.text, self.SINGLE_STRING_STOP) + assert_matched_stop(choice, self.SINGLE_STRING_STOP) + + def test_chat_rejects_multi_token_stop_string(self, model, api_client): + with pytest.raises((openai.BadRequestError, openai.APIStatusError)) as exc_info: + api_client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": self.STOP_SEQUENCE_TEST_PROMPT}], + max_tokens=10, + stop=["hello world"], + extra_body=self.NO_THINKING, + ) + assert_api_error(exc_info.value, 400, "unsupported_stop_string") + + def test_completion_rejects_multi_token_stop_string(self, model, api_client): + with pytest.raises((openai.BadRequestError, openai.APIStatusError)) as exc_info: + api_client.completions.create( + model=model, + prompt=self.STOP_SEQUENCE_TEST_PROMPT, + max_tokens=10, + stop=["hello world"], + ) + assert_api_error(exc_info.value, 400, "unsupported_stop_string") + + def test_completion_stop_string_streaming_final_chunk(self, model, api_client): + """Streaming completion: final chunk has finish_reason='stop' and matched_stop set.""" + stream = api_client.completions.create( + model=model, + prompt=self.STOP_SEQUENCE_TEST_PROMPT, + max_tokens=160, + temperature=0, + stop=[self.SINGLE_STRING_STOP], + stream=True, + ) + text, final_choice = collect_streamed_completion(stream) + assert final_choice.finish_reason == "stop" + assert_stop_text_trimmed(text, self.SINGLE_STRING_STOP) + assert_matched_stop(final_choice, self.SINGLE_STRING_STOP)