Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 12 additions & 60 deletions mbridge/models/qwen3_vl/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import logging
from typing import Optional

import torch
from megatron.core import InferenceParams, mpu, tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices

from mbridge.core.util import (
AllGatherVisionEmbeddings,
Expand Down Expand Up @@ -209,18 +207,18 @@ def forward(
position_ids: torch.Tensor = None, # can set at dataset
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
inference_params: Optional[InferenceParams] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
extra_block_kwargs: Optional[dict] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.Tensor] = None,
video_grid_thw: Optional[torch.Tensor] = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
pixel_values: torch.Tensor = None,
pixel_values_videos: torch.Tensor = None,
image_grid_thw: torch.Tensor = None,
video_grid_thw: torch.Tensor = None,
# can set at dataset
image_input_mask: torch.Tensor = None,
video_input_mask: torch.Tensor = None,
cp_img_num: Optional[list[int]] = None,
images_padded: Optional[list[bool]] = None,
cp_img_num: list[int] = None,
images_padded: list[bool] = None,
**kwargs,
) -> torch.Tensor:
"""Forward function of the Qwen3VL model.
Expand Down Expand Up @@ -250,14 +248,7 @@ def forward(
vision_data = None
vision_mask = None
deepstack_feature_lists = None
if packed_seq_params is not None and packed_seq_params.cp_group is not None:
cp_group = packed_seq_params.cp_group
else:
cp_group = mpu.get_context_parallel_group()
cp_size = cp_group.size()
cp_rank = cp_group.rank()

self.language_model.rotary_pos_emb.is_thd_format = packed_seq_params is not None
cp_size = mpu.get_context_parallel_world_size()

if self.pre_process:
# can reorganize_inputs at dataset
Expand Down Expand Up @@ -293,7 +284,6 @@ def forward(
self.square_merge_size,
cp_img_num,
images_padded,
group=cp_group,
)
)
vision_grid_thw = collapse_thw(vision_grid_thw)
Expand Down Expand Up @@ -321,13 +311,11 @@ def forward(
vision_embeds = AllGatherVisionEmbeddings.apply(
vision_embeds,
seqlen_on_cp_ranks,
cp_group,
)
for i in range(len(deepstack_feature_lists)):
deepstack_feature_lists[i] = AllGatherVisionEmbeddings.apply(
deepstack_feature_lists[i],
seqlen_on_cp_ranks,
cp_group,
)

combined_embeddings = self.language_model.embedding(
Expand All @@ -348,43 +336,7 @@ def forward(
combined_embeddings = split_data_cp_rank(
combined_embeddings, cp_size, 0
)

# Case: packed_seq_params is not None and attention_mask is None.
# This means input_ids were already packed before reaching this point.
if (
combined_embeddings is not None
and packed_seq_params is not None
and attention_mask is None
and cp_size > 1
and packed_seq_params.cu_seqlens_q_padded is not None
):
full_total_tokens = combined_embeddings.size(0)
assert full_total_tokens == input_ids.size(-1), f"{combined_embeddings.shape=} != {input_ids.shape=}"
index = get_thd_partitioned_indices(
packed_seq_params.cu_seqlens_q_padded,
full_total_tokens,
cp_size,
cp_rank,
)
# Split vision_mask by CP partition
vision_mask_local = vision_mask.index_select(1, index)
# deepstack_feature_lists
if deepstack_feature_lists is not None:
new_deepstack_feature_lists = []
for deepstack_visual_embed in deepstack_feature_lists:
tmp_embeddings = torch.zeros_like(combined_embeddings.transpose(0, 1))
tmp_embeddings[vision_mask] = deepstack_visual_embed
tmp_embeddings_thd = tmp_embeddings.index_select(1, index).squeeze(0).contiguous()
tmp_embeddings_thd = tmp_embeddings_thd[vision_mask_local.squeeze(0)].contiguous()
new_deepstack_feature_lists.append(tmp_embeddings_thd)
deepstack_feature_lists = new_deepstack_feature_lists
vision_mask = vision_mask_local
# combined_embeddings
combined_embeddings = combined_embeddings.index_select(0, index).contiguous()

# Case: packed_seq_params is not None and attention_mask is not None.
# This means input_ids still need to be packed here.
if packed_seq_params is not None and attention_mask is not None:
if packed_seq_params is not None:
input_ids_thd, _ = preprocess_packed_seqs(
input_ids, attention_mask, pre_process=True
)
Expand Down Expand Up @@ -442,7 +394,7 @@ def forward(
tp_size=mpu.get_tensor_model_parallel_world_size(),
tp_rank=mpu.get_tensor_model_parallel_rank(),
cp_size=cp_size,
cp_rank=cp_rank,
cp_rank=mpu.get_context_parallel_rank(),
sequence_parallel=self.config.sequence_parallel,
)
elif self.config.sequence_parallel: # THD and SP
Expand Down
7 changes: 4 additions & 3 deletions mbridge/models/qwen3_vl/rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def forward(
Tensor: Embeddings after applying RoPE.
"""
seq = position_ids.to(device=self.inv_freq.device, dtype=self.inv_freq.dtype)
cp_group = kwargs.get("cp_group", parallel_state.get_context_parallel_group())

if self.seq_len_interpolation_factor is not None:
seq *= 1 / self.seq_len_interpolation_factor
Expand All @@ -137,12 +136,14 @@ def forward(
# shape (seq_length, bs, 1, 2 * dim)
emb = emb[..., None, :].transpose(0, 1).contiguous()
if (
cp_group.size() > 1
parallel_state.get_context_parallel_world_size() > 1
and not self.is_thd_format
):
# slice rotary_pos_emb along the sequence dimension and select the partition of the current
# CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group)
emb = get_pos_emb_on_this_cp_rank(
emb, 0, parallel_state.get_context_parallel_group()
)
return emb


Expand Down
Loading