diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index 7e3fde991dd..bbaf87a11f4 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
+import transformers
+from packaging.version import Version
 from transformers import AutoModelForCausalLM
 
-from trl.models.utils import disable_gradient_checkpointing
+from trl.models.utils import _override_model_generation_config, disable_gradient_checkpointing
 
 
 class TestDisableGradientCheckpointing:
@@ -32,3 +35,56 @@ def test_when_enabled(self):
         with disable_gradient_checkpointing(model):
             assert model.is_gradient_checkpointing is False
         assert model.is_gradient_checkpointing is True
+
+
+class TestOverrideModelGenerationConfig:
+    """Tests for _override_model_generation_config (workaround for transformers#42762).
+
+    Qwen2.5 models ship with a model-specific generation_config that sets temperature
+    close to 0 (near-greedy). On transformers < 5.0, passing generation_config to
+    model.generate() merges model defaults on top, silently overriding training kwargs
+    such as temperature=1.0. _override_model_generation_config temporarily replaces the
+    model's generation_config with the training kwargs so the merge is a no-op.
+    """
+
+    def test_overrides_model_generation_config_during_context(self):
+        """Inside the context, model.generation_config reflects the training kwargs.
+
+        Only applies to transformers < 5.0; on v5+ the upstream bug is fixed and
+        the context manager is a no-op.
+        """
+        if Version(transformers.__version__) >= Version("5.0.0"):
+            pytest.skip("transformers >= 5.0 fixes the upstream bug; context manager is a no-op")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        training_kwargs = {"temperature": 1.0, "do_sample": True}
+        with _override_model_generation_config(model, generation_kwargs=training_kwargs):
+            assert model.generation_config.temperature == 1.0
+            assert model.generation_config.do_sample is True
+
+    def test_restores_original_generation_config_after_context(self):
+        """After the context, model.generation_config is restored to its original value."""
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        original_temperature = model.generation_config.temperature
+        original_config_id = id(model.generation_config)
+        training_kwargs = {"temperature": 1.0, "do_sample": True}
+        with _override_model_generation_config(model, generation_kwargs=training_kwargs):
+            pass
+        assert model.generation_config.temperature == original_temperature
+        assert id(model.generation_config) == original_config_id
+
+    def test_no_op_when_generation_kwargs_is_none(self):
+        """When generation_kwargs is None, the context manager yields without modifying the model."""
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        original_config_id = id(model.generation_config)
+        with _override_model_generation_config(model, generation_kwargs=None):
+            assert id(model.generation_config) == original_config_id
+
+    def test_no_op_on_transformers_v5(self):
+        """On transformers >= 5.0 the upstream bug is fixed; the context manager is a no-op."""
+        if Version(transformers.__version__) < Version("5.0.0"):
+            pytest.skip("only meaningful on transformers >= 5.0, where the context manager is a no-op")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        original_config_id = id(model.generation_config)
+        training_kwargs = {"temperature": 1.0}
+        with _override_model_generation_config(model, generation_kwargs=training_kwargs):
+            assert id(model.generation_config) == original_config_id
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 67c1c0ecc6b..6348701f559 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1362,7 +1362,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
             with (
                 profiling_context(self, "transformers.generate_batch"),
                 unwrap_model_for_generation(
-                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                    self.model_wrapped,
+                    self.accelerator,
+                    gather_deepspeed3_params=self.args.ds3_gather_for_generation,
+                    generation_kwargs=self.generation_kwargs,  # Override model.generation_config with generation_kwargs to fix transformers#42762
                 ) as unwrapped_model,
                 torch.no_grad(),
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 0613241dc6a..50d7e844fcb 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -981,7 +981,10 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
             with (
                 profiling_context(self, "transformers.generate_batch"),
                 unwrap_model_for_generation(
-                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                    self.model_wrapped,
+                    self.accelerator,
+                    gather_deepspeed3_params=self.args.ds3_gather_for_generation,
+                    generation_kwargs=self.generation_kwargs,  # Override model.generation_config with generation_kwargs to fix transformers#42762
                 ) as unwrapped_model,
                 torch.no_grad(),
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),