diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index ae789009..3c25a9fe 100755 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -416,7 +416,7 @@ def _embed( # This is a workaround to make sure the dummy embeddings are consumed while media_embeds.get("dummy"): dummy_embed = media_embeds["dummy"].popleft() - text_embeds += torch.sum(dummy_embed) * 0 + text_embeds = text_embeds + (torch.sum(dummy_embed) * 0) # Remove padding batch_size = labels.shape[0]