diff --git a/mbridge/models/qwen3_vl/model.py b/mbridge/models/qwen3_vl/model.py index 671e896..4a75c17 100644 --- a/mbridge/models/qwen3_vl/model.py +++ b/mbridge/models/qwen3_vl/model.py @@ -1,12 +1,10 @@ import logging -from typing import Optional import torch from megatron.core import InferenceParams, mpu, tensor_parallel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices from mbridge.core.util import ( AllGatherVisionEmbeddings, @@ -209,18 +207,18 @@ def forward( position_ids: torch.Tensor = None, # can set at dataset attention_mask: torch.Tensor = None, labels: torch.Tensor = None, - inference_params: Optional[InferenceParams] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - extra_block_kwargs: Optional[dict] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.Tensor] = None, - video_grid_thw: Optional[torch.Tensor] = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, # can set at dataset image_input_mask: torch.Tensor = None, video_input_mask: torch.Tensor = None, - cp_img_num: Optional[list[int]] = None, - images_padded: Optional[list[bool]] = None, + cp_img_num: list[int] = None, + images_padded: list[bool] = None, **kwargs, ) -> torch.Tensor: """Forward function of the Qwen3VL model. @@ -250,14 +248,7 @@ def forward( vision_data = None vision_mask = None deepstack_feature_lists = None - if packed_seq_params is not None and packed_seq_params.cp_group is not None: - cp_group = packed_seq_params.cp_group - else: - cp_group = mpu.get_context_parallel_group() - cp_size = cp_group.size() - cp_rank = cp_group.rank() - - self.language_model.rotary_pos_emb.is_thd_format = packed_seq_params is not None + cp_size = mpu.get_context_parallel_world_size() if self.pre_process: # can reorganize_inputs at dataset @@ -293,7 +284,6 @@ def forward( self.square_merge_size, cp_img_num, images_padded, - group=cp_group, ) ) vision_grid_thw = collapse_thw(vision_grid_thw) @@ -321,13 +311,11 @@ def forward( vision_embeds = AllGatherVisionEmbeddings.apply( vision_embeds, seqlen_on_cp_ranks, - cp_group, ) for i in range(len(deepstack_feature_lists)): deepstack_feature_lists[i] = AllGatherVisionEmbeddings.apply( deepstack_feature_lists[i], seqlen_on_cp_ranks, - cp_group, ) combined_embeddings = self.language_model.embedding( @@ -348,43 +336,7 @@ def forward( combined_embeddings = split_data_cp_rank( combined_embeddings, cp_size, 0 ) - - # packed_seq_params is not None and attention_mask is None: - # means we already packed input_ids - if ( - combined_embeddings is not None - and packed_seq_params is not None - and attention_mask is None - and cp_size > 1 - and packed_seq_params.cu_seqlens_q_padded is not None - ): - full_total_tokens = combined_embeddings.size(0) - assert full_total_tokens == input_ids.size(-1), f"{combined_embeddings.shape=} != {input_ids.shape=}" - index = get_thd_partitioned_indices( - packed_seq_params.cu_seqlens_q_padded, - full_total_tokens, - cp_size, - cp_rank, - ) - # Split vision_mask by CP partition - vision_mask_local = vision_mask.index_select(1, index) - # deepstack_feature_lists - if deepstack_feature_lists is not None: - new_deepstack_feature_lists = [] - for deepstack_visual_embed in deepstack_feature_lists: - tmp_embeddings = torch.zeros_like(combined_embeddings.transpose(0, 1)) - tmp_embeddings[vision_mask] = deepstack_visual_embed - tmp_embeddings_thd = tmp_embeddings.index_select(1, index).squeeze(0).contiguous() - tmp_embeddings_thd = tmp_embeddings_thd[vision_mask_local.squeeze(0)].contiguous() - new_deepstack_feature_lists.append(tmp_embeddings_thd) - deepstack_feature_lists = new_deepstack_feature_lists - vision_mask = vision_mask_local - # combined_embeddings - combined_embeddings = combined_embeddings.index_select(0, index).contiguous() - - # packed_seq_params is not None and attention_mask is not None: - # means we need packed input_ids in here - if packed_seq_params is not None and attention_mask is not None: + if packed_seq_params is not None: input_ids_thd, _ = preprocess_packed_seqs( input_ids, attention_mask, pre_process=True ) @@ -442,7 +394,7 @@ def forward( tp_size=mpu.get_tensor_model_parallel_world_size(), tp_rank=mpu.get_tensor_model_parallel_rank(), cp_size=cp_size, - cp_rank=cp_rank, + cp_rank=mpu.get_context_parallel_rank(), sequence_parallel=self.config.sequence_parallel, ) elif self.config.sequence_parallel: # THD and SP diff --git a/mbridge/models/qwen3_vl/rope_utils.py b/mbridge/models/qwen3_vl/rope_utils.py index fe2b612..27a3676 100644 --- a/mbridge/models/qwen3_vl/rope_utils.py +++ b/mbridge/models/qwen3_vl/rope_utils.py @@ -118,7 +118,6 @@ def forward( Tensor: Embeddings after applying RoPE. """ seq = position_ids.to(device=self.inv_freq.device, dtype=self.inv_freq.dtype) - cp_group = kwargs.get("cp_group", parallel_state.get_context_parallel_group()) if self.seq_len_interpolation_factor is not None: seq *= 1 / self.seq_len_interpolation_factor @@ -137,12 +136,14 @@ def forward( # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() if ( - cp_group.size() > 1 + parallel_state.get_context_parallel_world_size() > 1 and not self.is_thd_format ): # slice rotary_pos_emb along sequence dimension and select the parition of the current # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) + emb = get_pos_emb_on_this_cp_rank( + emb, 0, parallel_state.get_context_parallel_group() + ) return emb