Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions mlx_engine/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
StopPromptProcessing,
)
from mlx_engine.utils.generation_helpers import (
setup_repetition_penalty,
setup_logits_processors,
create_sampler,
validate_top_logprobs,
Expand Down Expand Up @@ -320,6 +319,10 @@ def create_generator(
discourage repetition
repetition_context_size (Optional[int]): Number of previous tokens to consider for
repetition penalty. Defaults to 20
presence_penalty (Optional[float]): Additive penalty applied to tokens present in the
context window, reducing token repetition
presence_context_size (Optional[int]): Number of recent tokens to consider for the
presence penalty. Defaults to 20
temp (Optional[float]): Temperature for sampling. Higher values increase randomness
top_p (Optional[float]): Top-p (nucleus) sampling parameter
top_k (Optional[int]): Top-k sampling parameter
Expand Down Expand Up @@ -402,6 +405,8 @@ def _sequential_generation(
top_logprobs: Optional[int] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
presence_penalty: Optional[float] = None,
presence_context_size: Optional[int] = 20,
temp: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand Down Expand Up @@ -437,11 +442,6 @@ def _sequential_generation(
if value is not None:
generate_args[attr] = value

# Set up repetition penalty
repetition_penalty_kwargs = setup_repetition_penalty(
repetition_penalty, repetition_context_size
)

# Set up speculative decoding
draft_model = determine_draft_model_for_generation(
model_kit, speculative_decoding_toggle
Expand Down Expand Up @@ -470,7 +470,9 @@ def _sequential_generation(
# Setup logits processors
logits_processors = setup_logits_processors(
repetition_penalty,
repetition_penalty_kwargs,
repetition_context_size,
presence_penalty,
presence_context_size,
prompt_tokens,
input_tokens,
None,
Expand Down Expand Up @@ -616,6 +618,8 @@ def _batched_generation(
top_logprobs: Optional[int] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
presence_penalty: Optional[float] = None,
presence_context_size: Optional[int] = 20,
temp: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -639,16 +643,13 @@ def _batched_generation(
if prompt_progress_reporter is None:
prompt_progress_reporter = DefaultPromptProgressReporter()

# Set up repetition penalty
repetition_penalty_kwargs = setup_repetition_penalty(
repetition_penalty, repetition_context_size
)

# Setup logits processors
tokenizer = model_kit.tokenizer
logits_processors = setup_logits_processors(
repetition_penalty,
repetition_penalty_kwargs,
repetition_context_size,
presence_penalty,
presence_context_size,
prompt_tokens,
input_tokens,
None,
Expand Down
50 changes: 0 additions & 50 deletions mlx_engine/processors/repetition_penalty_processor.py

This file was deleted.

28 changes: 28 additions & 0 deletions mlx_engine/processors/token_penalty_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Token penalty processor with KV cache awareness."""

from collections.abc import Callable

import mlx.core as mx


class TokenPenaltyProcessor:
    """Logits processor that feeds a penalty function a history-extended token window.

    Cached prefix tokens are prepended to the tokens passed on each call so the
    penalty window spans the full context, not just the tokens generated in the
    current turn.
    """

    def __init__(
        self,
        penalty_fn: Callable[[mx.array, mx.array], mx.array],
        token_history: list[int],
        context_size: int,
    ):
        """Keep the wrapped penalty function, the prefix tokens and the window size."""
        self._penalty_fn = penalty_fn
        self.token_history = token_history
        self.context_size = context_size

    def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array:
        """Return ``penalty_fn`` applied to ``logits`` with history prepended to ``tokens``."""
        # Only borrow from the history when fewer than context_size tokens were
        # passed in; take just enough of its tail to make up the difference.
        shortfall = self.context_size - len(tokens)
        if shortfall > 0:
            prefix = self.token_history[-shortfall:]
        else:
            prefix = []
        history = mx.array(prefix, dtype=mx.int64)
        return self._penalty_fn(mx.concat([history, tokens]), logits)
51 changes: 27 additions & 24 deletions mlx_engine/utils/generation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
"""

from typing import Optional, List, Tuple

from mlx_lm.sample_utils import make_presence_penalty, make_repetition_penalty, make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.sample_utils import make_sampler
from mlx_engine.processors.repetition_penalty_processor import (
RepetitionPenaltyProcessor,
)

from mlx_engine.processors.token_penalty_processor import TokenPenaltyProcessor
from mlx_engine.stop_string_processor import (
StopStringProcessor,
StopStringProcessorResult,
Expand All @@ -22,38 +22,41 @@
MAX_TOP_LOGPROBS = 10


def setup_repetition_penalty(
    repetition_penalty: Optional[float], repetition_context_size: Optional[int]
) -> dict:
    """Collect the repetition-penalty settings that were actually supplied.

    Returns an empty dict when ``repetition_penalty`` is None. Otherwise the
    dict holds ``"repetition_penalty"`` and, when it is not None,
    ``"repetition_context_size"``.
    """
    if repetition_penalty is None:
        return {}
    settings = {"repetition_penalty": repetition_penalty}
    if repetition_context_size is not None:
        settings["repetition_context_size"] = repetition_context_size
    return settings


def setup_logits_processors(
repetition_penalty: Optional[float],
repetition_penalty_kwargs: dict,
repetition_context_size: Optional[int],
presence_penalty: Optional[float],
presence_context_size: Optional[int],
prompt_tokens: List[int],
input_tokens: List[int],
json_schema: Optional[str],
tokenizer: TokenizerWrapper,
) -> List:
logits_processors = []
cached_tokens = []

if repetition_penalty and repetition_penalty != 0.0:
if repetition_penalty or presence_penalty:
cached_tokens = (
prompt_tokens[: -len(input_tokens)]
if len(input_tokens) > 0
else prompt_tokens
prompt_tokens[: -len(input_tokens)] if len(input_tokens) > 0 else prompt_tokens
)

if repetition_penalty and repetition_penalty != 0.0:
context_size = repetition_context_size if repetition_context_size is not None else 20
logits_processors.append(
TokenPenaltyProcessor(
make_repetition_penalty(repetition_penalty, context_size),
cached_tokens,
context_size,
)
)

if presence_penalty and presence_penalty != 0.0:
context_size = presence_context_size if presence_context_size is not None else 20
logits_processors.append(
RepetitionPenaltyProcessor(
token_history=cached_tokens, **repetition_penalty_kwargs
TokenPenaltyProcessor(
make_presence_penalty(presence_penalty, context_size),
cached_tokens,
context_size,
)
)

Expand Down
24 changes: 9 additions & 15 deletions tests/processors/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,18 @@ For example, we can add a `DumpLogitsProcessor` that writes the logits on each g
```diff
--- a/mlx_engine/generate.py
+++ b/mlx_engine/generate.py
@@ -12,6 +12,9 @@ from mlx_engine.processors.outlines_logits_processor import OutlinesLogitsProces
from mlx_engine.processors.repetition_penalty_processor import (
RepetitionPenaltyProcessor,
)
@@ -51,6 +51,9 @@ from mlx_engine.utils.generation_helpers import (
+from tests.processors.dump_logits_processor import (
+ DumpLogitsProcessor,
+)
from mlx_engine.utils.token import Token
from mlx_engine.utils.eot_tokens import get_eot_token_ids
from mlx_engine.utils.top_logprobs import summarize_top_logprobs
@@ -236,6 +239,9 @@ def create_generator(
token_history=cached_tokens, **repetition_penalty_kwargs
)
from mlx_engine.utils.generation_helpers import (
setup_logits_processors,
create_sampler,
@@ -480,6 +480,7 @@ def _sequential_generation(
)
+ generate_args["logits_processors"].append(
+ DumpLogitsProcessor(model_kit.tokenizer.vocab, Path("logits-dump"))
+ )

# Set up sampler
generate_args["sampler"] = make_sampler(
+ logits_processors.append(DumpLogitsProcessor(model_kit.tokenizer.vocab, Path("logits-dump")))
# Set up sampler
generate_args["sampler"] = create_sampler(
temp, top_p, min_p, min_tokens_to_keep, top_k
```
33 changes: 33 additions & 0 deletions tests/test_text_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ def generate() -> None:
"The quick brown fox jumped over the lazy dog.", generated_text
)

def test_presence_penalty_applies(self):
    """Generate with a strong presence penalty and check the prompt sentence is not echoed."""
    model_kit = load_model(
        model_path=model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit"),
        max_kv_size=4096,
    )
    self.addCleanup(lambda *_: unload(model_kit))
    prompt = """<|im_start|>user
The quick brown fox jumped over the lazy dog. The quick brown fox jumped over the lazy dog. The quick brown fox jumped over the lazy dog. Repeat what I said.
<|im_end|>\n<|im_start|>assistant\n"""
    prompt_tokens = tokenize(model_kit, prompt)

    stream = create_generator(
        model_kit=model_kit,
        prompt_tokens=prompt_tokens,
        presence_penalty=2.0,  # strong penalty to prevent repetition
        presence_context_size=64,
        seed=0,
        max_tokens=20,
        temp=0.0,
    )

    # Echo each chunk as it arrives and collect the pieces for the assertions.
    pieces = []
    for result in stream:
        print(result.text, end="", flush=True)
        pieces.append(result.text)
        if result.stop_condition:
            break
    print("\n", flush=True)
    generated_text = "".join(pieces)

    self.assertGreater(len(generated_text), 0, "Model failed to generate any text")
    self.assertNotIn(
        "The quick brown fox jumped over the lazy dog.", generated_text
    )

def test_prompt_caching_happy_path_qwen2_5(self):
model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")
model_kit = load_model(model_path=model_path, max_kv_size=20000)
Expand Down
Loading