diff --git a/mlx_lm/chat.py b/mlx_lm/chat.py index 44a4bed83..cc311f90c 100644 --- a/mlx_lm/chat.py +++ b/mlx_lm/chat.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import argparse +import readline # noqa: F401 # Enables terminal line editing/history on rank 0. import mlx.core as mx @@ -16,7 +17,33 @@ DEFAULT_XTC_THRESHOLD = 0.0 DEFAULT_SEED = 0 DEFAULT_MAX_TOKENS = 256 +DEFAULT_RENDER_WINDOW_SIZE = 20 +DEFAULT_REFRESH_RATE = 10 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" +DEFAULT_SYSTEM_PROMPT = ( + "You are a helpful assistant. Your responses are rendered in a terminal with " + "Markdown support. Feel free to use Markdown formatting when appropriate: " + "**bold**, *italic*, `inline code`, code blocks with syntax highlighting " + "(```language), bullet lists, numbered lists, and headers." +) + + +def broadcast_string( + value: str, group: mx.distributed.Group, src: int = 0 +) -> str: + """Broadcast a UTF-8 string from src to every rank in group.""" + if group.size() == 1: + return value + if group.rank() == src: + data = mx.array(value.encode("utf-8")) + mx.eval(mx.distributed.all_sum(data.size, group=group)) + mx.eval(mx.distributed.all_sum(data, group=group)) + return value + + size = mx.distributed.all_sum(0, group=group).item() + data = mx.distributed.all_sum(mx.zeros(size, dtype=mx.uint8), group=group) + mx.eval(data) + return bytes(data).decode("utf-8") def setup_arg_parser(): @@ -53,8 +80,8 @@ def setup_arg_parser(): parser.add_argument( "--xtc-threshold", type=float, - default=0.0, - help="Thresold the probs of each next token candidate to be sampled by XTC", + default=DEFAULT_XTC_THRESHOLD, + help="Threshold the probs of each next token candidate to be sampled by XTC", ) parser.add_argument( "--seed", @@ -78,13 +105,31 @@ def setup_arg_parser(): parser.add_argument( "--system-prompt", default=None, - help="System prompt to be used for the chat template", + help="System prompt to be used for the chat template " + "(replaces the default Markdown-aware prompt)", + ) + parser.add_argument( + "--no-system-prompt", + action="store_true", + help="Disable the default system prompt entirely", ) parser.add_argument( "--pipeline", action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--window-size", + type=int, + default=DEFAULT_RENDER_WINDOW_SIZE, + help="The number of recent rendered lines to keep in the live panel", + ) + parser.add_argument( + "--refresh-rate", + type=int, + default=DEFAULT_REFRESH_RATE, + help="The live panel refresh rate during generation", + ) return parser @@ -115,7 +160,11 @@ def main(): with ChatUI(args, rank=rank) as ui: prompt_cache = make_prompt_cache(model, args.max_kv_size) while True: - query = ui.prompt() + query = ui.prompt() if rank == 0 else "" + query = broadcast_string(query, group).strip() + + if not query: + continue if query == "q": ui.say_bye() break @@ -127,7 +176,10 @@ def main(): ui.say_help() continue messages = [] - if args.system_prompt is not None: + if not args.no_system_prompt: + system_content = args.system_prompt or DEFAULT_SYSTEM_PROMPT + messages.append({"role": "system", "content": system_content}) + elif args.system_prompt: messages.append({"role": "system", "content": args.system_prompt}) messages.append({"role": "user", "content": query}) prompt = tokenizer.apply_chat_template( diff --git a/mlx_lm/cli_ui.py b/mlx_lm/cli_ui.py index 133d1190e..b7c404136 100644 --- a/mlx_lm/cli_ui.py +++ b/mlx_lm/cli_ui.py @@ -1,6 +1,7 @@ # Copyright © 2024 Apple Inc. import re +import readline # noqa: F401 # Enables terminal line editing/history. import shutil import sys from contextlib import contextmanager @@ -9,6 +10,8 @@ import mlx.core as mx from rich.box import ROUNDED from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown from rich.panel import Panel from rich.progress import Progress, ProgressColumn, TextColumn from rich.text import Text @@ -51,6 +54,8 @@ def make_console() -> Console: theme=_make_theme(), highlight=False, color_system="truecolor", + force_terminal=True, + force_interactive=True, width=_terminal_width(), ) @@ -256,6 +261,10 @@ def __init__(self, args, rank: int = 0): self._rank = rank self._args = args self._console = make_console() + self._response_text = "" + self._live = None + self._window_size = max(getattr(args, "window_size", 20), 1) + self._refresh_rate = max(getattr(args, "refresh_rate", 10), 1) def __enter__(self): if self._rank == 0: @@ -294,16 +303,59 @@ def say_help(self): if self._rank == 0: print_chat_help(self._console) + def _display_text(self) -> str: + lines = self._response_text.splitlines(keepends=True) + if len(lines) > self._window_size: + return "".join(lines[-self._window_size :]) + return self._response_text + + def _ensure_live(self): + if self._rank != 0 or self._live is not None: + return + self._live = Live( + Panel( + Markdown(""), + title="[ui.accent]generating[/ui.accent]", + border_style="ui.border", + box=ROUNDED, + ), + console=self._console, + refresh_per_second=self._refresh_rate, + transient=True, + ) + self._live.start() + def stream_token(self, text: str): - rprint(text, flush=True, end="") + if self._rank != 0: + return + self._ensure_live() + self._response_text += text + self._live.update( + Panel( + Markdown(self._display_text()), + title="[ui.accent]generating[/ui.accent]", + border_style="ui.border", + box=ROUNDED, + ) + ) def end_turn(self, response): - rprint() # newline after the streamed line + if self._live is not None: + self._live.stop() + self._live = None if self._rank != 0 or response is None: + self._response_text = "" return + self._console.print(Markdown(self._response_text)) + if getattr(response, "finish_reason", None) == "length": + self._console.print( + f" [ui.warn]output truncated[/ui.warn] " + f"[ui.muted](max tokens: {self._args.max_tokens})[/ui.muted]" + ) self._console.print( f" [ui.muted]{response.generation_tokens} tokens · " f"{response.generation_tps:.1f} tok/s · " f"prompt {response.prompt_tps:.1f} tok/s · " f"peak {response.peak_memory:.2f} GB[/ui.muted]" ) + self._response_text = "" diff --git a/mlx_lm/sample_utils.py b/mlx_lm/sample_utils.py index 05a45fc60..53f6d3f5f 100644 --- a/mlx_lm/sample_utils.py +++ b/mlx_lm/sample_utils.py @@ -141,7 +141,7 @@ def apply_top_k( vocab_size = logprobs.shape[-1] if not isinstance(top_k, int) or not (0 < top_k < vocab_size): raise ValueError( - f"`top_k` has to be an integer in the (0, {vocab_size}] interval," + f"`top_k` has to be an integer in the (0, {vocab_size}) interval," f" but is {top_k}." ) mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] diff --git a/setup.py b/setup.py index dd769a706..42ea22a2d 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "protobuf", "pyyaml", "jinja2", + "rich", ], packages=[ "mlx_lm", diff --git a/tests/test_chat.py b/tests/test_chat.py index b4ddcf3ae..39a439745 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import MagicMock, patch -from mlx_lm.chat import setup_arg_parser +from mlx_lm.chat import broadcast_string, setup_arg_parser class TestChat(unittest.TestCase): @@ -42,6 +42,10 @@ def test_setup_arg_parser_all_args(self): "512", "--system-prompt", "You are a helpful assistant.", + "--window-size", + "30", + "--refresh-rate", + "12", ] ) @@ -55,32 +59,90 @@ def test_setup_arg_parser_all_args(self): self.assertEqual(args.max_kv_size, 1024) self.assertEqual(args.max_tokens, 512) self.assertEqual(args.system_prompt, "You are a helpful assistant.") + self.assertEqual(args.window_size, 30) + self.assertEqual(args.refresh_rate, 12) - @patch("mlx_lm.chat.load") + def test_broadcast_string_single_rank_returns_input(self): + group = MagicMock() + group.size.return_value = 1 + + self.assertEqual(broadcast_string("hello", group), "hello") + + @patch("mlx_lm.chat.sharded_load") + @patch("mlx_lm.chat.make_prompt_cache") + @patch("mlx_lm.chat.stream_generate") + @patch("mlx_lm.chat.broadcast_string") + @patch("mlx_lm.chat.ChatUI") + def test_root_rank_prompts_then_broadcasts_input( + self, + mock_chat_ui, + mock_broadcast_string, + mock_stream_generate, + mock_make_prompt_cache, + mock_sharded_load, + ): + from mlx_lm.chat import main + + group = MagicMock() + group.rank.return_value = 0 + group.size.return_value = 2 + + mock_sharded_load.return_value = (MagicMock(), MagicMock()) + mock_make_prompt_cache.return_value = MagicMock() + mock_broadcast_string.return_value = "q" + + ui = MagicMock() + ui.prompt.return_value = "q" + mock_chat_ui.return_value.__enter__.return_value = ui + + with patch("mlx_lm.chat.mx.distributed.init", return_value=group), patch( + "sys.argv", ["chat.py"] + ): + main() + + ui.prompt.assert_called_once() + mock_broadcast_string.assert_called_once_with("q", group) + mock_stream_generate.assert_not_called() + + def test_no_system_prompt_flag(self): + parser = setup_arg_parser() + args = parser.parse_args(["--no-system-prompt"]) + + self.assertTrue(args.no_system_prompt) + + @patch("mlx_lm.chat.sharded_load") @patch("mlx_lm.chat.make_prompt_cache") @patch("mlx_lm.chat.stream_generate") - @patch("builtins.input") - @patch("builtins.print") + @patch("mlx_lm.chat.broadcast_string") + @patch("mlx_lm.chat.ChatUI") def test_system_prompt_in_messages( self, - mock_print, - mock_input, + mock_chat_ui, + mock_broadcast_string, mock_stream_generate, mock_make_prompt_cache, - mock_load, + mock_sharded_load, ): from mlx_lm.chat import main + group = MagicMock() + group.rank.return_value = 0 + group.size.return_value = 2 + # Mock the model and tokenizer mock_model = MagicMock() mock_tokenizer = MagicMock() mock_tokenizer.apply_chat_template.return_value = "processed_prompt" - mock_load.return_value = (mock_model, mock_tokenizer) + mock_sharded_load.return_value = (mock_model, mock_tokenizer) # Mock prompt cache mock_prompt_cache = MagicMock() mock_make_prompt_cache.return_value = mock_prompt_cache + ui = MagicMock() + ui.prompt.side_effect = ["What is the weather?", "q"] + mock_chat_ui.return_value.__enter__.return_value = ui + # Mock stream_generate to return some responses mock_response = MagicMock() mock_response.text = "Hello there!" @@ -89,12 +151,13 @@ def test_system_prompt_in_messages( mock_response.prompt_tps = 1.0 mock_response.peak_memory = 1.0 mock_stream_generate.return_value = [mock_response] - - # Mock user input: first a question, then 'q' to quit - mock_input.side_effect = ["What is the weather?", "q"] + mock_broadcast_string.side_effect = lambda query, *_args, **_kwargs: query # Test with system prompt with patch( + "mlx_lm.chat.mx.distributed.init", + return_value=group, + ), patch( "sys.argv", ["chat.py", "--system-prompt", "You are a weather assistant."] ): try: @@ -114,33 +177,85 @@ def test_system_prompt_in_messages( self.assertEqual(call_args[0]["content"], "You are a weather assistant.") self.assertEqual(call_args[1]["role"], "user") self.assertEqual(call_args[1]["content"], "What is the weather?") + mock_broadcast_string.assert_any_call("What is the weather?", group) - @patch("mlx_lm.chat.load") + @patch("mlx_lm.chat.sharded_load") @patch("mlx_lm.chat.make_prompt_cache") @patch("mlx_lm.chat.stream_generate") - @patch("builtins.input") - @patch("builtins.print") + @patch("mlx_lm.chat.broadcast_string") + @patch("mlx_lm.chat.ChatUI") + def test_default_system_prompt_in_messages( + self, + mock_chat_ui, + mock_broadcast_string, + mock_stream_generate, + mock_make_prompt_cache, + mock_sharded_load, + ): + from mlx_lm.chat import DEFAULT_SYSTEM_PROMPT, main + + group = MagicMock() + group.rank.return_value = 0 + group.size.return_value = 2 + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = "processed_prompt" + mock_sharded_load.return_value = (mock_model, mock_tokenizer) + mock_make_prompt_cache.return_value = MagicMock() + + ui = MagicMock() + ui.prompt.side_effect = ["Hello", "q"] + mock_chat_ui.return_value.__enter__.return_value = ui + + mock_response = MagicMock() + mock_response.text = "Hi!" + mock_response.generation_tokens = 1 + mock_response.generation_tps = 1.0 + mock_response.prompt_tps = 1.0 + mock_response.peak_memory = 1.0 + mock_stream_generate.return_value = [mock_response] + mock_broadcast_string.side_effect = lambda query, *_args, **_kwargs: query + + with patch("mlx_lm.chat.mx.distributed.init", return_value=group), patch( + "sys.argv", ["chat.py"] + ): + main() + + call_args = mock_tokenizer.apply_chat_template.call_args[0][0] + self.assertEqual(call_args[0]["role"], "system") + self.assertEqual(call_args[0]["content"], DEFAULT_SYSTEM_PROMPT) + self.assertEqual(call_args[1]["content"], "Hello") + + @patch("mlx_lm.chat.sharded_load") + @patch("mlx_lm.chat.make_prompt_cache") + @patch("mlx_lm.chat.stream_generate") + @patch("mlx_lm.chat.broadcast_string") + @patch("mlx_lm.chat.ChatUI") def test_no_system_prompt_in_messages( self, - mock_print, - mock_input, + mock_chat_ui, + mock_broadcast_string, mock_stream_generate, mock_make_prompt_cache, - mock_load, + mock_sharded_load, ): from mlx_lm.chat import main - # Mock the model and tokenizer + group = MagicMock() + group.rank.return_value = 0 + group.size.return_value = 2 + mock_model = MagicMock() mock_tokenizer = MagicMock() mock_tokenizer.apply_chat_template.return_value = "processed_prompt" - mock_load.return_value = (mock_model, mock_tokenizer) + mock_sharded_load.return_value = (mock_model, mock_tokenizer) + mock_make_prompt_cache.return_value = MagicMock() - # Mock prompt cache - mock_prompt_cache = MagicMock() - mock_make_prompt_cache.return_value = mock_prompt_cache + ui = MagicMock() + ui.prompt.side_effect = ["What is the weather?", "q"] + mock_chat_ui.return_value.__enter__.return_value = ui - # Mock stream_generate to return some responses mock_response = MagicMock() mock_response.text = "Hello there!" mock_response.generation_tokens = 1 @@ -148,27 +263,55 @@ def test_no_system_prompt_in_messages( mock_response.prompt_tps = 1.0 mock_response.peak_memory = 1.0 mock_stream_generate.return_value = [mock_response] + mock_broadcast_string.side_effect = lambda query, *_args, **_kwargs: query + + with patch("mlx_lm.chat.mx.distributed.init", return_value=group), patch( + "sys.argv", ["chat.py", "--no-system-prompt"] + ): + main() + + call_args = mock_tokenizer.apply_chat_template.call_args[0][0] + self.assertEqual(len(call_args), 1) + self.assertEqual(call_args[0]["role"], "user") + self.assertEqual(call_args[0]["content"], "What is the weather?") - # Mock user input: first a question, then 'q' to quit - mock_input.side_effect = ["What is the weather?", "q"] + @patch("mlx_lm.chat.sharded_load") + @patch("mlx_lm.chat.make_prompt_cache") + @patch("mlx_lm.chat.stream_generate") + @patch("mlx_lm.chat.broadcast_string") + @patch("mlx_lm.chat.ChatUI") + def test_non_root_rank_uses_broadcast_without_prompt( + self, + mock_chat_ui, + mock_broadcast_string, + mock_stream_generate, + mock_make_prompt_cache, + mock_sharded_load, + ): + from mlx_lm.chat import main + + group = MagicMock() + group.rank.return_value = 1 + group.size.return_value = 2 + + mock_sharded_load.return_value = (MagicMock(), MagicMock()) + mock_make_prompt_cache.return_value = MagicMock() + mock_broadcast_string.return_value = "q" + + ui = MagicMock() + mock_chat_ui.return_value.__enter__.return_value = ui - # Test without system prompt - with patch("sys.argv", ["chat.py"]): + with patch("mlx_lm.chat.mx.distributed.init", return_value=group), patch( + "sys.argv", ["chat.py"] + ): try: main() except SystemExit: pass - # Verify that apply_chat_template was called without system prompt - mock_tokenizer.apply_chat_template.assert_called() - call_args = mock_tokenizer.apply_chat_template.call_args[0][ - 0 - ] # First positional arg (messages) - - # Check that the messages contain only user message - self.assertEqual(len(call_args), 1) - self.assertEqual(call_args[0]["role"], "user") - self.assertEqual(call_args[0]["content"], "What is the weather?") + ui.prompt.assert_not_called() + mock_broadcast_string.assert_called_once_with("", group) + mock_stream_generate.assert_not_called() if __name__ == "__main__": diff --git a/tests/test_cli_ui.py b/tests/test_cli_ui.py new file mode 100644 index 000000000..76001ca27 --- /dev/null +++ b/tests/test_cli_ui.py @@ -0,0 +1,61 @@ +import argparse +import unittest +from unittest.mock import MagicMock, patch + +from rich.markdown import Markdown +from rich.panel import Panel + +from mlx_lm.cli_ui import ChatUI, make_console + + +class TestCliUI(unittest.TestCase): + + def setUp(self): + make_console.cache_clear() + + def tearDown(self): + make_console.cache_clear() + + def test_make_console_forces_terminal_interactive_mode(self): + console = make_console() + + self.assertTrue(console.is_terminal) + self.assertTrue(console.is_interactive) + + @patch("mlx_lm.cli_ui.Live") + def test_rank_zero_streams_with_live_markdown_buffer(self, mock_live): + live = MagicMock() + mock_live.return_value = live + args = argparse.Namespace(window_size=20, refresh_rate=7, max_tokens=128) + + ui = ChatUI(args, rank=0) + ui.stream_token("hello") + ui.stream_token(" **world**") + + mock_live.assert_called_once() + live.start.assert_called_once() + self.assertEqual(live.update.call_count, 2) + + initial_panel = mock_live.call_args.args[0] + self.assertIsInstance(initial_panel, Panel) + self.assertIsInstance(initial_panel.renderable, Markdown) + self.assertEqual(mock_live.call_args.kwargs["refresh_per_second"], 7) + + updated_panel = live.update.call_args.args[0] + self.assertIsInstance(updated_panel, Panel) + self.assertIsInstance(updated_panel.renderable, Markdown) + self.assertEqual(ui._response_text, "hello **world**") + + @patch("mlx_lm.cli_ui.Live") + def test_non_root_stream_does_not_start_live(self, mock_live): + args = argparse.Namespace(window_size=20, refresh_rate=7, max_tokens=128) + + ui = ChatUI(args, rank=1) + ui.stream_token("hidden") + + mock_live.assert_not_called() + self.assertEqual(ui._response_text, "") + + +if __name__ == "__main__": + unittest.main()