Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,14 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
else:
total_input = input

output = torch.matmul(total_input, weight.t())
if len(weight.shape) == 2:
output = torch.matmul(total_input, weight.t())
elif len(weight.shape) == 3:
total_input = total_input.transpose(0, 1)
output = torch.matmul(total_input, weight.transpose(1, 2))
output = output.transpose(0, 1)
else:
raise RuntimeError(f"Unable to perform matmul between tensors of shape {total_input.shape} and {weight.shape}")
if bias is not None:
output = output + bias
return output
Expand Down Expand Up @@ -540,6 +547,7 @@ def __init__(self, input_size, output_size, *,
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce



def forward(self,
input_: torch.Tensor,
weight: Optional[torch.Tensor] = None):
Expand All @@ -563,11 +571,14 @@ def forward(self,
weight = self.weight
else:
# Check the weight passed in is the correct shape
expected_shape = (self.output_size_per_partition, self.input_size)
if len(weight.shape) == 2:
expected_shape = (self.output_size_per_partition, self.input_size)
elif len(weight.shape) == 3:
expected_shape = (input_.shape[1], self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected")

bias = self.bias if not self.skip_bias_add else None

if self.async_tensor_model_parallel_allreduce or \
Expand All @@ -586,12 +597,15 @@ def forward(self,
)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel
assert not self.sequence_parallel_enabled
output_parallel = output_parallel.contiguous()
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias




class RowParallelLinear(torch.nn.Module):
Expand Down Expand Up @@ -694,11 +708,13 @@ def __init__(self, input_size: int, output_size: int, *,
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce


def forward(self, input_):
def forward(self, input_,
weight: Optional[torch.Tensor] = None):
"""Forward of RowParallelLinear

Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
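weight: optional 2D or 3D weight tensor; when None, self.weight is used. A 3D weight must have input_.shape[1] as its leading dimension.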

Returns:
- output
Expand All @@ -710,10 +726,25 @@ def forward(self, input_):
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)

if weight is None:
if self.weight is None:
raise RuntimeError("weight was not supplied to ColumnParallelLinear forward pass "
"and skip_weight_param_allocation is True.")
weight = self.weight
else:
# Check the weight passed in is the correct shape
if len(weight.shape) == 2:
expected_shape = (self.output_size_per_partition, self.input_size)
elif len(weight.shape) == 3:
expected_shape = (input_.shape[1], self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected")
# Matrix multiply.
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
weight=weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
Expand Down