-
Notifications
You must be signed in to change notification settings - Fork 775
feat: enhance chat CLI with readline history, line editing, and distributed support #1382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8675d56
0383cdb
e887f88
aed9b59
d2ac7bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| # Copyright © 2023-2024 Apple Inc. | ||
|
|
||
| import argparse | ||
| import readline # noqa: F401 # Enables terminal line editing/history on rank 0. | ||
|
|
||
| import mlx.core as mx | ||
|
|
||
|
|
@@ -16,7 +17,33 @@ | |
| DEFAULT_XTC_THRESHOLD = 0.0 | ||
| DEFAULT_SEED = 0 | ||
| DEFAULT_MAX_TOKENS = 256 | ||
| DEFAULT_RENDER_WINDOW_SIZE = 20 | ||
| DEFAULT_REFRESH_RATE = 10 | ||
| DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" | ||
| DEFAULT_SYSTEM_PROMPT = ( | ||
| "You are a helpful assistant. Your responses are rendered in a terminal with " | ||
| "Markdown support. Feel free to use Markdown formatting when appropriate: " | ||
| "**bold**, *italic*, `inline code`, code blocks with syntax highlighting " | ||
| "(```language), bullet lists, numbered lists, and headers." | ||
| ) | ||
|
|
||
|
|
||
| def broadcast_string( | ||
| value: str, group: mx.distributed.Group, src: int = 0 | ||
| ) -> str: | ||
| """Broadcast a UTF-8 string from src to every rank in group.""" | ||
| if group.size() == 1: | ||
| return value | ||
| if group.rank() == src: | ||
| data = mx.array(value.encode("utf-8")) | ||
| mx.eval(mx.distributed.all_sum(data.size, group=group)) | ||
| mx.eval(mx.distributed.all_sum(data, group=group)) | ||
| return value | ||
|
|
||
| size = mx.distributed.all_sum(0, group=group).item() | ||
| data = mx.distributed.all_sum(mx.zeros(size, dtype=mx.uint8), group=group) | ||
| mx.eval(data) | ||
| return bytes(data).decode("utf-8") | ||
|
|
||
|
|
||
| def setup_arg_parser(): | ||
|
|
@@ -53,8 +80,8 @@ def setup_arg_parser(): | |
| parser.add_argument( | ||
| "--xtc-threshold", | ||
| type=float, | ||
| default=0.0, | ||
| help="Thresold the probs of each next token candidate to be sampled by XTC", | ||
| default=DEFAULT_XTC_THRESHOLD, | ||
| help="Threshold the probs of each next token candidate to be sampled by XTC", | ||
| ) | ||
| parser.add_argument( | ||
| "--seed", | ||
|
|
@@ -78,13 +105,31 @@ def setup_arg_parser(): | |
| parser.add_argument( | ||
| "--system-prompt", | ||
| default=None, | ||
| help="System prompt to be used for the chat template", | ||
| help="System prompt to be used for the chat template " | ||
| "(replaces the default Markdown-aware prompt)", | ||
| ) | ||
| parser.add_argument( | ||
| "--no-system-prompt", | ||
| action="store_true", | ||
| help="Disable the default system prompt entirely", | ||
| ) | ||
| parser.add_argument( | ||
| "--pipeline", | ||
| action="store_true", | ||
| help="Use pipelining instead of tensor parallelism", | ||
| ) | ||
| parser.add_argument( | ||
| "--window-size", | ||
| type=int, | ||
| default=DEFAULT_RENDER_WINDOW_SIZE, | ||
| help="The number of recent rendered lines to keep in the live panel", | ||
| ) | ||
| parser.add_argument( | ||
| "--refresh-rate", | ||
| type=int, | ||
| default=DEFAULT_REFRESH_RATE, | ||
| help="The live panel refresh rate during generation", | ||
| ) | ||
| return parser | ||
|
|
||
|
|
||
|
|
@@ -115,7 +160,11 @@ def main(): | |
| with ChatUI(args, rank=rank) as ui: | ||
| prompt_cache = make_prompt_cache(model, args.max_kv_size) | ||
| while True: | ||
| query = ui.prompt() | ||
| query = ui.prompt() if rank == 0 else "" | ||
| query = broadcast_string(query, group).strip() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand correctly, here we communicate prompt other ranks, but I am not sure that I understand what is the goal.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The goal of that section is to make interactive chat work correctly in distributed mode. What problem it solves:
So the flow is:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The prompt is read separately on each rank from its own stdin. Could you please clarify why non-root ranks would hang and why different ranks could end up with inconsistent input?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A simple failure may looks like this:
Now one rank is waiting for compute synchronization while the other is still waiting for terminal input
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you have this issue?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does not append to me no. |
||
|
|
||
| if not query: | ||
| continue | ||
| if query == "q": | ||
| ui.say_bye() | ||
| break | ||
|
|
@@ -127,7 +176,10 @@ def main(): | |
| ui.say_help() | ||
| continue | ||
| messages = [] | ||
| if args.system_prompt is not None: | ||
| if not args.no_system_prompt: | ||
| system_content = args.system_prompt or DEFAULT_SYSTEM_PROMPT | ||
| messages.append({"role": "system", "content": system_content}) | ||
| elif args.system_prompt: | ||
| messages.append({"role": "system", "content": args.system_prompt}) | ||
| messages.append({"role": "user", "content": query}) | ||
| prompt = tokenizer.apply_chat_template( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| "protobuf", | ||
| "pyyaml", | ||
| "jinja2", | ||
| "rich", | ||
| ], | ||
| packages=[ | ||
| "mlx_lm", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why exactly this is needed?