diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py index 51e40fa9..8a794c48 100644 --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py @@ -49,7 +49,6 @@ StopPromptProcessing, ) from mlx_engine.utils.generation_helpers import ( - setup_repetition_penalty, setup_logits_processors, create_sampler, validate_top_logprobs, @@ -320,6 +319,10 @@ def create_generator( discourage repetition repetition_context_size (Optional[int]): Number of previous tokens to consider for repetition penalty. Defaults to 20 + presence_penalty (Optional[float]): Additive penalty applied to tokens present in the + context window, reducing token repetition + presence_context_size (Optional[int]): Number of recent tokens to consider for the + presence penalty. Defaults to 20 temp (Optional[float]): Temperature for sampling. Higher values increase randomness top_p (Optional[float]): Top-p (nucleus) sampling parameter top_k (Optional[int]): Top-k sampling parameter @@ -402,6 +405,8 @@ def _sequential_generation( top_logprobs: Optional[int] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = 20, + presence_penalty: Optional[float] = None, + presence_context_size: Optional[int] = 20, temp: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -437,11 +442,6 @@ def _sequential_generation( if value is not None: generate_args[attr] = value - # Set up repetition penalty - repetition_penalty_kwargs = setup_repetition_penalty( - repetition_penalty, repetition_context_size - ) - # Set up speculative decoding draft_model = determine_draft_model_for_generation( model_kit, speculative_decoding_toggle @@ -470,7 +470,9 @@ def _sequential_generation( # Setup logits processors logits_processors = setup_logits_processors( repetition_penalty, - repetition_penalty_kwargs, + repetition_context_size, + presence_penalty, + presence_context_size, prompt_tokens, input_tokens, None, @@ -616,6 +618,8 @@ def _batched_generation( top_logprobs: Optional[int] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = 20, + presence_penalty: Optional[float] = None, + presence_context_size: Optional[int] = 20, temp: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -639,16 +643,13 @@ def _batched_generation( if prompt_progress_reporter is None: prompt_progress_reporter = DefaultPromptProgressReporter() - # Set up repetition penalty - repetition_penalty_kwargs = setup_repetition_penalty( - repetition_penalty, repetition_context_size - ) - # Setup logits processors tokenizer = model_kit.tokenizer logits_processors = setup_logits_processors( repetition_penalty, - repetition_penalty_kwargs, + repetition_context_size, + presence_penalty, + presence_context_size, prompt_tokens, input_tokens, None, diff --git a/mlx_engine/processors/repetition_penalty_processor.py b/mlx_engine/processors/repetition_penalty_processor.py deleted file mode 100644 index 425cf963..00000000 --- a/mlx_engine/processors/repetition_penalty_processor.py +++ /dev/null @@ -1,50 +0,0 @@ -import mlx.core as mx -from mlx_lm.sample_utils import make_repetition_penalty - -""" -Wrapper for the standard mlx-lm repetition penalty processor -ref: https://github.com/ml-explore/mlx-lm/blob/69195f8632869d35306d085de7dc4e7d6954baac/mlx_lm/sample_utils.py#L245-L255 - -This wrapper enables the repetition penalty processor to take into account the tokens that have already been cached, -without the need for recomputing the logits for those tokens. -""" - - -class RepetitionPenaltyProcessor: - def __init__( - self, - token_history: list[int], - repetition_penalty: float, - repetition_context_size: int, - ): - self.token_history = token_history - self.repetition_context_size = repetition_context_size - self.repetition_penalty_function = make_repetition_penalty( - repetition_penalty, repetition_context_size - ) - - def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array: - """ - Apply repetition penalty to the logits, accounting for tokens that have already been processed within - the same prediction. - - Args: - tokens: The tokens to be processed. - logits: The logits to be processed. - """ - # append historical tokens s.t. repetition penalty accounts tokens that have already been processed in this gen - num_tokens_to_prepend_from_history = max( - self.repetition_context_size - len(tokens), 0 - ) - historical_tokens = ( - self.token_history[-num_tokens_to_prepend_from_history:] - if num_tokens_to_prepend_from_history > 0 - else [] - ) - historical_tokens_mx = mx.array( - historical_tokens, - dtype=mx.int64, - ) - all_tokens_to_consider = mx.concat([historical_tokens_mx, tokens]) - result = self.repetition_penalty_function(all_tokens_to_consider, logits) - return result diff --git a/mlx_engine/processors/token_penalty_processor.py b/mlx_engine/processors/token_penalty_processor.py new file mode 100644 index 00000000..c0aa6347 --- /dev/null +++ b/mlx_engine/processors/token_penalty_processor.py @@ -0,0 +1,28 @@ +"""Token penalty processor with KV cache awareness.""" + +from collections.abc import Callable + +import mlx.core as mx + + +class TokenPenaltyProcessor: + # Prepends cached prefix tokens so the penalty window spans the full context, + # not just the tokens generated in the current turn. + + def __init__( + self, + penalty_fn: Callable[[mx.array, mx.array], mx.array], + token_history: list[int], + context_size: int, + ): + self.token_history = token_history + self.context_size = context_size + self._penalty_fn = penalty_fn + + def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array: + num_to_prepend = max(self.context_size - len(tokens), 0) + historical = ( + self.token_history[-num_to_prepend:] if num_to_prepend > 0 else [] + ) + all_tokens = mx.concat([mx.array(historical, dtype=mx.int64), tokens]) + return self._penalty_fn(all_tokens, logits) diff --git a/mlx_engine/utils/generation_helpers.py b/mlx_engine/utils/generation_helpers.py index 824147dd..a30c53a0 100644 --- a/mlx_engine/utils/generation_helpers.py +++ b/mlx_engine/utils/generation_helpers.py @@ -6,11 +6,11 @@ """ from typing import Optional, List, Tuple + +from mlx_lm.sample_utils import make_presence_penalty, make_repetition_penalty, make_sampler from mlx_lm.tokenizer_utils import TokenizerWrapper -from mlx_lm.sample_utils import make_sampler -from mlx_engine.processors.repetition_penalty_processor import ( - RepetitionPenaltyProcessor, -) + +from mlx_engine.processors.token_penalty_processor import TokenPenaltyProcessor from mlx_engine.stop_string_processor import ( StopStringProcessor, StopStringProcessorResult, @@ -22,38 +22,41 @@ MAX_TOP_LOGPROBS = 10 -def setup_repetition_penalty( - repetition_penalty: Optional[float], repetition_context_size: Optional[int] -) -> dict: - repetition_penalty_kwargs = {} - if repetition_penalty is not None: - repetition_penalty_kwargs["repetition_penalty"] = repetition_penalty - if repetition_context_size is not None: - repetition_penalty_kwargs["repetition_context_size"] = ( - repetition_context_size - ) - return repetition_penalty_kwargs - - def setup_logits_processors( repetition_penalty: Optional[float], - repetition_penalty_kwargs: dict, + repetition_context_size: Optional[int], + presence_penalty: Optional[float], + presence_context_size: Optional[int], prompt_tokens: List[int], input_tokens: List[int], json_schema: Optional[str], tokenizer: TokenizerWrapper, ) -> List: logits_processors = [] + cached_tokens = [] - if repetition_penalty and repetition_penalty != 0.0: + if repetition_penalty or presence_penalty: cached_tokens = ( - prompt_tokens[: -len(input_tokens)] - if len(input_tokens) > 0 - else prompt_tokens + prompt_tokens[: -len(input_tokens)] if len(input_tokens) > 0 else prompt_tokens ) + + if repetition_penalty and repetition_penalty != 0.0: + context_size = repetition_context_size if repetition_context_size is not None else 20 + logits_processors.append( + TokenPenaltyProcessor( + make_repetition_penalty(repetition_penalty, context_size), + cached_tokens, + context_size, + ) + ) + + if presence_penalty and presence_penalty != 0.0: + context_size = presence_context_size if presence_context_size is not None else 20 logits_processors.append( - RepetitionPenaltyProcessor( - token_history=cached_tokens, **repetition_penalty_kwargs + TokenPenaltyProcessor( + make_presence_penalty(presence_penalty, context_size), + cached_tokens, + context_size, ) ) diff --git a/tests/processors/README.md b/tests/processors/README.md index 42ba80b6..3f9f64b4 100644 --- a/tests/processors/README.md +++ b/tests/processors/README.md @@ -5,24 +5,18 @@ For example, we can add a `DumpLogitsProcessor` that writes the logits on each g ```diff --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py -@@ -12,6 +12,9 @@ from mlx_engine.processors.outlines_logits_processor import OutlinesLogitsProces - from mlx_engine.processors.repetition_penalty_processor import ( - RepetitionPenaltyProcessor, - ) +@@ -51,6 +51,9 @@ from mlx_engine.utils.generation_helpers import ( +from tests.processors.dump_logits_processor import ( + DumpLogitsProcessor, +) - from mlx_engine.utils.token import Token - from mlx_engine.utils.eot_tokens import get_eot_token_ids - from mlx_engine.utils.top_logprobs import summarize_top_logprobs -@@ -236,6 +239,9 @@ def create_generator( - token_history=cached_tokens, **repetition_penalty_kwargs - ) + from mlx_engine.utils.generation_helpers import ( + setup_logits_processors, + create_sampler, +@@ -480,6 +480,7 @@ def _sequential_generation( ) -+ generate_args["logits_processors"].append( -+ DumpLogitsProcessor(model_kit.tokenizer.vocab, Path("logits-dump")) -+ ) - # Set up sampler - generate_args["sampler"] = make_sampler( ++ logits_processors.append(DumpLogitsProcessor(model_kit.tokenizer.vocab, Path("logits-dump"))) + # Set up sampler + generate_args["sampler"] = create_sampler( + temp, top_p, min_p, min_tokens_to_keep, top_k ``` diff --git a/tests/test_text_models.py b/tests/test_text_models.py index 27926cdb..d4753fd7 100644 --- a/tests/test_text_models.py +++ b/tests/test_text_models.py @@ -61,6 +61,39 @@ def generate() -> None: "The quick brown fox jumped over the lazy dog.", generated_text ) + def test_presence_penalty_applies(self): + model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit") + model_kit = load_model(model_path=model_path, max_kv_size=4096) + self.addCleanup(lambda *_: unload(model_kit)) + prompt = """<|im_start|>user +The quick brown fox jumped over the lazy dog. The quick brown fox jumped over the lazy dog. The quick brown fox jumped over the lazy dog. Repeat what I said. +<|im_end|>\n<|im_start|>assistant\n""" + prompt_tokens = tokenize(model_kit, prompt) + generated_text = "" + + def generate() -> None: + nonlocal generated_text + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + presence_penalty=2.0, # strong penalty to prevent repetition + presence_context_size=64, + seed=0, + max_tokens=20, + temp=0.0, + ): + print(result.text, end="", flush=True) + generated_text += result.text + if result.stop_condition: + break + print("\n", flush=True) + + generate() + self.assertGreater(len(generated_text), 0, "Model failed to generate any text") + self.assertNotIn( + "The quick brown fox jumped over the lazy dog.", generated_text + ) + def test_prompt_caching_happy_path_qwen2_5(self): model_path = model_getter("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit") model_kit = load_model(model_path=model_path, max_kv_size=20000)