|
self, input_ids, past=None, attention_mask=None, token_type_ids=None, **model_kwargs |
transformers模型的generate方法代码中,上文状态的参数名为past_key_values,不是past;
以transformers.GenerationMixin.sample方法为例,在其主循环while True中,在上一循环末尾会调用self._update_model_kwargs_for_generation 更新model_kwargs,此时更新进去的参数名为past_key_values;但到了下一循环调用self.prepare_inputs_for_generation时,又要从model_kwargs中取past参数,从而取不到,导致重复计算(基于更新了的input_ids重新算一遍上文状态),并且token_type_ids会一直取0,而不是预期的在生成第一个token时取0,此后一直取1。
RoFormer_pytorch/src/roformer/modeling_roformer.py
Line 1465 in 447aa8f
transformers模型的generate方法代码中,上文状态的参数名为past_key_values,不是past;
以transformers.GenerationMixin.sample方法为例,在其主循环while True中,在上一循环末尾会调用self._update_model_kwargs_for_generation 更新model_kwargs,此时更新进去的参数名为past_key_values;但到了下一循环调用self.prepare_inputs_for_generation时,又要从model_kwargs中取past参数,从而取不到,导致重复计算(基于更新了的input_ids重新算一遍上文状态),并且token_type_ids会一直取0,而不是预期的在生成第一个token时取0,此后一直取1。