Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions mlx_lm/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import readline # noqa: F401 # Enables terminal line editing/history on rank 0.

import mlx.core as mx

Expand All @@ -16,7 +17,33 @@
DEFAULT_XTC_THRESHOLD = 0.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_RENDER_WINDOW_SIZE = 20
DEFAULT_REFRESH_RATE = 10
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_SYSTEM_PROMPT = (
"You are a helpful assistant. Your responses are rendered in a terminal with "
"Markdown support. Feel free to use Markdown formatting when appropriate: "
"**bold**, *italic*, `inline code`, code blocks with syntax highlighting "
"(```language), bullet lists, numbered lists, and headers."
)


def broadcast_string(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why exactly is this needed?

value: str, group: mx.distributed.Group, src: int = 0
) -> str:
"""Broadcast a UTF-8 string from src to every rank in group."""
if group.size() == 1:
return value
if group.rank() == src:
data = mx.array(value.encode("utf-8"))
mx.eval(mx.distributed.all_sum(data.size, group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return value

size = mx.distributed.all_sum(0, group=group).item()
data = mx.distributed.all_sum(mx.zeros(size, dtype=mx.uint8), group=group)
mx.eval(data)
return bytes(data).decode("utf-8")


def setup_arg_parser():
Expand Down Expand Up @@ -53,8 +80,8 @@ def setup_arg_parser():
parser.add_argument(
"--xtc-threshold",
type=float,
default=0.0,
help="Thresold the probs of each next token candidate to be sampled by XTC",
default=DEFAULT_XTC_THRESHOLD,
help="Threshold the probs of each next token candidate to be sampled by XTC",
)
parser.add_argument(
"--seed",
Expand All @@ -78,13 +105,31 @@ def setup_arg_parser():
parser.add_argument(
"--system-prompt",
default=None,
help="System prompt to be used for the chat template",
help="System prompt to be used for the chat template "
"(replaces the default Markdown-aware prompt)",
)
parser.add_argument(
"--no-system-prompt",
action="store_true",
help="Disable the default system prompt entirely",
)
parser.add_argument(
"--pipeline",
action="store_true",
help="Use pipelining instead of tensor parallelism",
)
parser.add_argument(
"--window-size",
type=int,
default=DEFAULT_RENDER_WINDOW_SIZE,
help="The number of recent rendered lines to keep in the live panel",
)
parser.add_argument(
"--refresh-rate",
type=int,
default=DEFAULT_REFRESH_RATE,
help="The live panel refresh rate during generation",
)
return parser


Expand Down Expand Up @@ -115,7 +160,11 @@ def main():
with ChatUI(args, rank=rank) as ui:
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = ui.prompt()
query = ui.prompt() if rank == 0 else ""
query = broadcast_string(query, group).strip()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, here we communicate the prompt to the other ranks, but I am not sure I understand what the goal is.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal of that section is to make interactive chat work correctly in distributed mode.

What problem it solves:

  • In distributed execution, several processes/ranks are running at once.
  • We do not want every rank to ask the user for input.
  • We want only rank 0 to read the prompt from the terminal.
  • Then we want that same prompt to be sent to all the other ranks so they all generate from the exact same user message.

So the flow is:
User enters: "Hello"
Rank 0 reads "Hello"
Rank 1, rank 2, etc. read ""
broadcast_string(...) sends "Hello" from rank 0 to everyone; after that, every rank has "Hello".
Without this, distributed chat would break in one of these ways:

  • every rank would try to read from stdin
  • non-root ranks would hang
  • different ranks could end up with inconsistent input

@nastya236 nastya236 Jun 9, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prompt is read separately on each rank from its own stdin. Could you please clarify why non-root ranks would hang and why different ranks could end up with inconsistent input?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A simple failure may look like this:

  • rank 0 reads "Hello"
  • rank 1 is still blocked in ui.prompt()
  • rank 0 starts generation or reaches a collective operation
  • rank 1 has not reached the same point yet

Now one rank is waiting for compute synchronization while the other is still waiting for terminal input.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you have this issue?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it has not happened to me.
I added this because a friend and I were talking about it and decided that it was safer!


if not query:
continue
if query == "q":
ui.say_bye()
break
Expand All @@ -127,7 +176,10 @@ def main():
ui.say_help()
continue
messages = []
if args.system_prompt is not None:
if not args.no_system_prompt:
system_content = args.system_prompt or DEFAULT_SYSTEM_PROMPT
messages.append({"role": "system", "content": system_content})
elif args.system_prompt:
messages.append({"role": "system", "content": args.system_prompt})
messages.append({"role": "user", "content": query})
prompt = tokenizer.apply_chat_template(
Expand Down
56 changes: 54 additions & 2 deletions mlx_lm/cli_ui.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2024 Apple Inc.

import re
import readline # noqa: F401 # Enables terminal line editing/history.
import shutil
import sys
from contextlib import contextmanager
Expand All @@ -9,6 +10,8 @@
import mlx.core as mx
from rich.box import ROUNDED
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import Progress, ProgressColumn, TextColumn
from rich.text import Text
Expand Down Expand Up @@ -51,6 +54,8 @@ def make_console() -> Console:
theme=_make_theme(),
highlight=False,
color_system="truecolor",
force_terminal=True,
force_interactive=True,
width=_terminal_width(),
)

Expand Down Expand Up @@ -256,6 +261,10 @@ def __init__(self, args, rank: int = 0):
self._rank = rank
self._args = args
self._console = make_console()
self._response_text = ""
self._live = None
self._window_size = max(getattr(args, "window_size", 20), 1)
self._refresh_rate = max(getattr(args, "refresh_rate", 10), 1)

def __enter__(self):
if self._rank == 0:
Expand Down Expand Up @@ -294,16 +303,59 @@ def say_help(self):
if self._rank == 0:
print_chat_help(self._console)

def _display_text(self) -> str:
    """Return the tail of the response text shown in the live panel.

    At most ``self._window_size`` trailing source lines of
    ``self._response_text`` are kept. The window is counted in raw text
    lines, not rendered rows.

    If the cut falls inside a fenced code block (``` or ~~~), the line that
    opened the fence is put back in front of the window. Without it the
    Markdown renderer would see an unbalanced fence and render the visible
    code as prose (and the prose after the closing fence as code).

    Returns:
        The text to hand to the Markdown renderer.
    """
    lines = self._response_text.splitlines(keepends=True)
    if len(lines) <= self._window_size:
        return self._response_text

    # Walk the lines that scroll out of view and track whether a fence is
    # still open when the visible window begins.
    open_fence = None  # (opening line, fence char, fence length)
    for line in lines[: -self._window_size]:
        stripped = line.strip()
        if open_fence is None:
            if stripped.startswith(("```", "~~~")):
                char = stripped[0]
                run = len(stripped) - len(stripped.lstrip(char))
                open_fence = (line, char, run)
        else:
            _, char, run = open_fence
            # A closing fence is only fence characters, at least as long
            # as the opening run.
            if stripped and set(stripped) == {char} and len(stripped) >= run:
                open_fence = None

    tail = "".join(lines[-self._window_size :])
    if open_fence is None:
        return tail
    return open_fence[0] + tail

def _ensure_live(self):
    """Create and start the live "generating" panel if it is not running yet.

    Does nothing on ranks other than 0, or when a live display already
    exists on this instance.
    """
    if self._rank != 0:
        return
    if self._live is not None:
        return

    # Start from an empty Markdown body inside the panel.
    placeholder = Panel(
        Markdown(""),
        title="[ui.accent]generating[/ui.accent]",
        border_style="ui.border",
        box=ROUNDED,
    )
    self._live = Live(
        placeholder,
        console=self._console,
        refresh_per_second=self._refresh_rate,
        transient=True,
    )
    self._live.start()

def stream_token(self, text: str):
    """Append a generated text segment and redraw the live panel.

    Only rank 0 renders; every other rank returns immediately. The text is
    accumulated in ``self._response_text`` and shown through the live
    Markdown panel only. It is no longer also written raw to the terminal,
    which would duplicate the output next to the panel.

    Args:
        text: The newly generated text to append to the current response.
    """
    if self._rank != 0:
        return
    self._ensure_live()
    self._response_text += text
    # Re-render the windowed response as Markdown inside the panel.
    self._live.update(
        Panel(
            Markdown(self._display_text()),
            title="[ui.accent]generating[/ui.accent]",
            border_style="ui.border",
            box=ROUNDED,
        )
    )

def end_turn(self, response):
    """Finish the current turn: stop the live panel and print the result.

    The live display (if any) is stopped and cleared first. On rank 0 with
    a non-``None`` response, the accumulated text is printed once as
    Markdown, followed by a truncation warning when
    ``response.finish_reason == "length"`` and a statistics line. The
    accumulated text is reset in every case.

    No bare newline is printed up front any more: the text is not streamed
    raw to the terminal, so there is no partial line to terminate.

    Args:
        response: The last generation response, or ``None`` when there is
            nothing to report. It is read for ``finish_reason``,
            ``generation_tokens``, ``generation_tps``, ``prompt_tps`` and
            ``peak_memory``.
    """
    # Stop the live panel before printing the final output.
    if self._live is not None:
        self._live.stop()
        self._live = None
    if self._rank != 0 or response is None:
        self._response_text = ""
        return
    self._console.print(Markdown(self._response_text))
    if getattr(response, "finish_reason", None) == "length":
        self._console.print(
            f" [ui.warn]output truncated[/ui.warn] "
            f"[ui.muted](max tokens: {self._args.max_tokens})[/ui.muted]"
        )
    self._console.print(
        f" [ui.muted]{response.generation_tokens} tokens · "
        f"{response.generation_tps:.1f} tok/s · "
        f"prompt {response.prompt_tps:.1f} tok/s · "
        f"peak {response.peak_memory:.2f} GB[/ui.muted]"
    )
    self._response_text = ""
2 changes: 1 addition & 1 deletion mlx_lm/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def apply_top_k(
vocab_size = logprobs.shape[-1]
if not isinstance(top_k, int) or not (0 < top_k < vocab_size):
raise ValueError(
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
f"`top_k` has to be an integer in the (0, {vocab_size}) interval,"
f" but is {top_k}."
)
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"protobuf",
"pyyaml",
"jinja2",
"rich",
],
packages=[
"mlx_lm",
Expand Down
Loading