Suspected major bug: RoformerForCasaulLM.prepare_inputs_for_generation: the "past" parameter should be renamed to "past_key_values" #45

@zjupeter

Description

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, token_type_ids=None, **model_kwargs):

In the transformers generate code, the argument that carries the cached context state is named past_key_values, not past.
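The mismatch comes down to how Python binds keyword arguments: a key that matches no named parameter is swallowed by **model_kwargs, so past keeps its default of None. The sketch below uses two hypothetical standalone functions (prepare_inputs_buggy and prepare_inputs_fixed, written without self) that only copy the shape of the signature quoted above; the cache is a toy tuple, not real RoFormer tensors.

```python
# Hypothetical stand-ins that copy the shape of the quoted signature only;
# they are not the real RoFormer implementation.


def prepare_inputs_buggy(input_ids, past=None, attention_mask=None,
                         token_type_ids=None, **model_kwargs):
    # Reads the cache from a parameter named `past`
    return {"past": past, "unused_kwargs": sorted(model_kwargs)}


def prepare_inputs_fixed(input_ids, past_key_values=None, attention_mask=None,
                         token_type_ids=None, **model_kwargs):
    # Reads the cache from `past_key_values`, the key generate() writes back
    return {"past": past_key_values, "unused_kwargs": sorted(model_kwargs)}


# Toy version of what the generation loop passes back after the first step
model_kwargs = {"past_key_values": ("k0", "v0"), "attention_mask": [1, 1, 1]}

buggy = prepare_inputs_buggy([101, 7, 8], **model_kwargs)
fixed = prepare_inputs_fixed([101, 7, 8], **model_kwargs)
print(buggy)  # {'past': None, 'unused_kwargs': ['past_key_values']}
print(fixed)  # {'past': ('k0', 'v0'), 'unused_kwargs': []}
```

In the buggy version the cache is silently absorbed into **model_kwargs; no error is raised, which is why the problem only shows up as wasted computation.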
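Take transformers.GenerationMixin.sample as an example. At the end of each pass through its main while True loop, it calls self._update_model_kwargs_for_generation to update model_kwargs, and the cache is stored there under the name past_key_values. On the next pass, however, self.prepare_inputs_for_generation looks for past in model_kwargs and does not find it. As a result, the context state is recomputed from scratch on the updated input_ids at every step, and token_type_ids stays 0 throughout, instead of being 0 while generating the first token and 1 for every token after that.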
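A toy simulation (plain Python, no transformers dependency) makes both symptoms visible. Everything here is hypothetical: ToyModel only counts how many positions it encodes, the loop runs a fixed number of steps instead of while True, and two behaviors are assumptions chosen to match this report: that the model feeds only the newest token with segment id 1 once a cache is present, and that token_type_ids is extended by repeating its last value.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyModel:
    """Hypothetical stand-in for a causal LM; it counts work, no real math."""

    def __init__(self) -> None:
        self.positions_encoded = 0  # total positions pushed through the "network"

    def forward(self, input_ids: List[int],
                past: Optional[Tuple[int, ...]]) -> Tuple[int, Tuple[int, ...]]:
        self.positions_encoded += len(input_ids)
        cache = (past or ()) + tuple(input_ids)  # toy cache: every token seen so far
        next_token = len(cache)                   # toy "prediction"
        return next_token, cache


def prepare_buggy(input_ids, past=None, token_type_ids=None, **model_kwargs):
    # Looks for `past`, but the loop stores the cache as `past_key_values`
    if past is None:
        return input_ids, None, token_type_ids
    return input_ids[-1:], past, [1]


def prepare_fixed(input_ids, past_key_values=None, token_type_ids=None, **model_kwargs):
    if past_key_values is None:
        return input_ids, None, token_type_ids
    # Cached step (assumed behavior): feed only the newest token, segment id 1
    return input_ids[-1:], past_key_values, [1]


def generate(prepare, prompt: List[int], steps: int):
    model = ToyModel()
    input_ids = list(prompt)
    model_kwargs: Dict[str, Any] = {"token_type_ids": [0] * len(prompt)}
    fed_token_types: List[List[int]] = []
    for _ in range(steps):
        ids, past, tt = prepare(input_ids, **model_kwargs)
        fed_token_types.append(tt)
        next_token, cache = model.forward(ids, past)
        input_ids.append(next_token)
        # Like _update_model_kwargs_for_generation: cache saved as past_key_values
        model_kwargs["past_key_values"] = cache
        # Assumed bookkeeping: extend token_type_ids by repeating the last value
        tts = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = tts + tts[-1:]
    return model.positions_encoded, fed_token_types, input_ids


buggy = generate(prepare_buggy, [5, 6, 7], steps=4)
fixed = generate(prepare_fixed, [5, 6, 7], steps=4)
print(buggy[0], fixed[0])  # 18 6
print(buggy[1][-1])        # [0, 0, 0, 0, 0, 0]
print(fixed[1])            # [[0, 0, 0], [1], [1], [1]]
```

Both runs produce the same tokens, but the buggy version re-encodes the whole growing sequence at every step (3 + 4 + 5 + 6 = 18 positions versus 6) and never feeds a segment id of 1.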
