diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 26436dbc8e8..00ff166ff1e 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -237,7 +237,14 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, else: total_input = input - output = torch.matmul(total_input, weight.t()) + if len(weight.shape) == 2: + output = torch.matmul(total_input, weight.t()) + elif len(weight.shape) == 3: + total_input = total_input.transpose(0, 1) + output = torch.matmul(total_input, weight.transpose(1, 2)) + output = output.transpose(0, 1) + else: + raise RuntimeError(f"Unable to perform matmul between tensors of shape {total_input.shape} and {weight.shape}") if bias is not None: output = output + bias return output @@ -540,6 +547,7 @@ def __init__(self, input_size, output_size, *, self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): @@ -563,11 +571,14 @@ def forward(self, weight = self.weight else: # Check the weight passed in is the correct shape - expected_shape = (self.output_size_per_partition, self.input_size) + if len(weight.shape) == 2: + expected_shape = (self.output_size_per_partition, self.input_size) + elif len(weight.shape) == 3: + expected_shape = (input_.shape[1], self.output_size_per_partition, self.input_size) if weight.shape != expected_shape: raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, " f"not {expected_shape} as expected") - + bias = self.bias if not self.skip_bias_add else None if self.async_tensor_model_parallel_allreduce or \ @@ -586,12 +597,15 @@ def forward(self, ) if self.gather_output: # All-gather across the partitions. - assert not self.sequence_parallel + assert not self.sequence_parallel_enabled + output_parallel = output_parallel.contiguous() output = gather_from_tensor_model_parallel_region(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None return output, output_bias + + class RowParallelLinear(torch.nn.Module): @@ -694,11 +708,13 @@ def __init__(self, input_size: int, output_size: int, *, self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - def forward(self, input_): + def forward(self, input_, + weight: Optional[torch.Tensor] = None): """Forward of RowParallelLinear Args: input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + weight: 3D or 2D tensor of weights Returns: - output @@ -710,10 +726,25 @@ def forward(self, input_): else: assert not self.sequence_parallel input_parallel = scatter_to_tensor_model_parallel_region(input_) + + if weight is None: + if self.weight is None: + raise RuntimeError("weight was not supplied to ColumnParallelLinear forward pass " + "and skip_weight_param_allocation is True.") + weight = self.weight + else: + # Check the weight passed in is the correct shape + if len(weight.shape) == 2: + expected_shape = (self.output_size_per_partition, self.input_size) + elif len(weight.shape) == 3: + expected_shape = (input_.shape[1], self.output_size_per_partition, self.input_size) + if weight.shape != expected_shape: + raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, " + f"not {expected_shape} as expected") # Matrix multiply. output_parallel = self._forward_impl( input=input_parallel, - weight=self.weight, + weight=weight, bias=None, gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=False,