From ce39fa77851840a0e7d418589dcf10dff2d3fb0e Mon Sep 17 00:00:00 2001 From: Abhinav Khattar Date: Wed, 12 Jul 2023 13:17:18 -0700 Subject: [PATCH 1/4] custom_wt --- megatron/core/tensor_parallel/layers.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 15e0fbb025a..59d896cc842 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -538,16 +538,34 @@ def __init__(self, input_size, output_size, *, ) - def forward(self, input_): + def forward(self, + input_: torch.Tensor, + weight: Optional[torch.Tensor] = None): """Forward of ColumnParallelLinear Args: input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + weight (optional): weight tensor to use, compulsory when + skip_weight_param_allocation is True. + Returns: - output - bias + """ + if weight is None: + if self.weight is None: + raise RuntimeError("weight was not supplied to ColumnParallelLinear forward pass " + "and skip_weight_param_allocation is True.") + weight = self.weight + else: + # Check the weight passed in is the correct shape + expected_shape = (self.output_size_per_partition, self.input_size) + if weight.shape != expected_shape: + raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, " + f"not {expected_shape} as expected") + bias = self.bias if not self.skip_bias_add else None if self.async_tensor_model_parallel_allreduce or \ @@ -558,7 +576,7 @@ def forward(self, input_): # Matrix multiply. output_parallel = linear_with_grad_accumulation_and_async_allreduce( input=input_parallel, - weight=self.weight, + weight=weight, bias=bias, gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=self.async_tensor_model_parallel_allreduce, From 701c062e5c1231cbe082309a4c70ff4d243913ee Mon Sep 17 00:00:00 2001 From: arendu Date: Mon, 17 Jul 2023 16:03:05 -0700 Subject: [PATCH 2/4] accept a batch of weights of shape (batch, n, m) Signed-off-by: arendu --- megatron/core/tensor_parallel/layers.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 59d896cc842..1047d3a3584 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -240,7 +240,14 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, else: total_input = input - output = torch.matmul(total_input, weight.t()) + if len(weight.shape) == 2: + output = torch.matmul(total_input, weight.t()) + elif len(weight.shape) == 3: + total_input = total_input.transpose(0, 1) + output = torch.matmul(total_input, weight.transpose(1, 2)) + output = output.transpose(0, 1) + else: + raise RuntimeError(f"Unable to perform matmul between tensors of shape {total_input.shape} and {weight.shape}") if bias is not None: output = output + bias return output @@ -561,7 +568,10 @@ def forward(self, weight = self.weight else: # Check the weight passed in is the correct shape - expected_shape = (self.output_size_per_partition, self.input_size) + if len(weight.shape) == 2: + expected_shape = (self.output_size_per_partition, self.input_size) + elif len(weight.shape) == 3: + expected_shape = (input_.shape[1], self.output_size_per_partition, self.input_size) if weight.shape != expected_shape: raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, " f"not {expected_shape} as expected") @@ -590,6 +600,8 @@ def forward(self, output = output_parallel output_bias = self.bias if self.skip_bias_add else None return output, output_bias + + class RowParallelLinear(torch.nn.Module): From 671cfe8fdab98b990884ad039e2daaa28c55b4ec Mon Sep 17 00:00:00 2001 From: arendu Date: Tue, 18 Jul 2023 12:06:45 -0700 Subject: [PATCH 3/4] added weight arg to forward for RowParallelLinear Signed-off-by: arendu --- megatron/core/tensor_parallel/layers.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 1047d3a3584..63e0c14e67b 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -706,11 +706,13 @@ def __init__(self, input_size, output_size, *, - def forward(self, input_): + def forward(self, input_, + weight: Optional[torch.Tensor] = None): """Forward of RowParallelLinear Args: input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + weight: 3D or 2D tensor of weights Returns: - output @@ -722,10 +724,25 @@ def forward(self, input_): else: assert not self.sequence_parallel_enabled input_parallel = scatter_to_tensor_model_parallel_region(input_) + + if weight is None: + if self.weight is None: + raise RuntimeError("weight was not supplied to ColumnParallelLinear forward pass " + "and skip_weight_param_allocation is True.") + weight = self.weight + else: + # Check the weight passed in is the correct shape + if len(weight.shape) == 2: + expected_shape = (self.output_size_per_partition, self.input_size) + elif len(weight.shape) == 3: + expected_shape = (input_.shape[1], self.output_size_per_partition, self.input_size) + if weight.shape != expected_shape: + raise RuntimeError(f"supplied weight's shape is {tuple(weight.shape)}, " + f"not {expected_shape} as expected") # Matrix multiply. output_parallel = linear_with_grad_accumulation_and_async_allreduce( input=input_parallel, - weight=self.weight, + weight=weight, bias=None, gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=False, From f62dcbba25cf0e20e51d509c878d47af6c31ed9d Mon Sep 17 00:00:00 2001 From: arendu Date: Fri, 21 Jul 2023 23:33:01 -0700 Subject: [PATCH 4/4] needed a contiguous call Signed-off-by: arendu --- megatron/core/tensor_parallel/layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 63e0c14e67b..d7665ffd2e6 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -595,6 +595,7 @@ def forward(self, if self.gather_output: # All-gather across the partitions. assert not self.sequence_parallel_enabled + output_parallel = output_parallel.contiguous() output = gather_from_tensor_model_parallel_region(output_parallel) else: output = output_parallel