Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion tests/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import transformers
from packaging.version import Version
from transformers import AutoModelForCausalLM

from trl.models.utils import disable_gradient_checkpointing
from trl.models.utils import _override_model_generation_config, disable_gradient_checkpointing


class TestDisableGradientCheckpointing:
Expand All @@ -32,3 +35,56 @@ def test_when_enabled(self):
with disable_gradient_checkpointing(model):
assert model.is_gradient_checkpointing is False
assert model.is_gradient_checkpointing is True


class TestOverrideModelGenerationConfig:
"""Tests for _override_model_generation_config (workaround for transformers#42762).

Qwen2.5 models ship with a model-specific generation_config that sets temperature
close to 0 (near-greedy). On transformers < 5.0, passing generation_config to
model.generate() merges model defaults on top, silently overriding training kwargs
such as temperature=1.0. _override_model_generation_config temporarily replaces the
model's generation_config with the training kwargs so the merge is a no-op.
"""

def test_overrides_model_generation_config_during_context(self):
"""Inside the context, model.generation_config reflects the training kwargs.

Only applies to transformers < 5.0; on v5+ the upstream bug is fixed and
the context manager is a no-op.
"""
if Version(transformers.__version__) >= Version("5.0.0"):
pytest.skip("transformers >= 5.0 fixes the upstream bug; context manager is a no-op")
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
training_kwargs = {"temperature": 1.0, "do_sample": True}
with _override_model_generation_config(model, generation_kwargs=training_kwargs):
assert model.generation_config.temperature == 1.0
assert model.generation_config.do_sample is True

def test_restores_original_generation_config_after_context(self):
"""After the context, model.generation_config is restored to its original value."""
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
original_temperature = model.generation_config.temperature
original_config_id = id(model.generation_config)
training_kwargs = {"temperature": 1.0, "do_sample": True}
with _override_model_generation_config(model, generation_kwargs=training_kwargs):
pass
assert model.generation_config.temperature == original_temperature
assert id(model.generation_config) == original_config_id

def test_no_op_when_generation_kwargs_is_none(self):
"""When generation_kwargs is None, the context manager yields without modifying the model."""
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
original_config_id = id(model.generation_config)
with _override_model_generation_config(model, generation_kwargs=None):
assert id(model.generation_config) == original_config_id

def test_no_op_on_transformers_v5(self):
"""On transformers >= 5.0 the upstream bug is fixed; the context manager is a no-op."""
if Version(transformers.__version__) < Version("5.0.0"):
return # only meaningful on transformers v5+
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
original_config_id = id(model.generation_config)
training_kwargs = {"temperature": 1.0}
with _override_model_generation_config(model, generation_kwargs=training_kwargs):
assert id(model.generation_config) == original_config_id
5 changes: 4 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
self.model_wrapped,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762
) as unwrapped_model,
torch.no_grad(),
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
Expand Down
5 changes: 4 additions & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
self.model_wrapped,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762
) as unwrapped_model,
torch.no_grad(),
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
Expand Down
Loading